fix(promtail): Handle docker logs when a log is split in multiple frames (#12374)

Jonas L. B 1 year ago committed by GitHub
parent 76ba24e3d8
commit c0113db4e8
Files changed (lines changed in parentheses):
  1. LICENSING.md (1)
  2. clients/pkg/promtail/promtail.go (2)
  3. clients/pkg/promtail/targets/docker/target.go (162)
  4. clients/pkg/promtail/targets/docker/target_group.go (2)
  5. clients/pkg/promtail/targets/docker/target_test.go (83)
  6. clients/pkg/promtail/targets/docker/targetmanager.go (2)
  7. clients/pkg/promtail/targets/docker/targetmanager_test.go (1)
  8. clients/pkg/promtail/targets/docker/testdata/partial-tty.log (1)
  9. clients/pkg/promtail/targets/docker/testdata/partial.log (BIN)
  10. clients/pkg/promtail/targets/manager.go (4)
  11. docs/sources/send-data/promtail/configuration.md (9)
  12. pkg/framedstdcopy/framedstdcopy.go (173)
  13. pkg/framedstdcopy/framedstdcopy_test.go (269)

@@ -10,6 +10,7 @@ The following folders and their subfolders are licensed under Apache-2.0:
```
clients/
pkg/framedstdcopy/
pkg/ingester/wal
pkg/logproto/
pkg/loghttp/

@@ -184,7 +184,7 @@ func (p *Promtail) reloadConfig(cfg *config.Config) error {
entryHandlers = append(entryHandlers, p.client)
p.entriesFanout = utils.NewFanoutEntryHandler(timeoutUntilFanoutHardStop, entryHandlers...)
tms, err := targets.NewTargetManagers(p, p.reg, p.logger, cfg.PositionsConfig, p.entriesFanout, cfg.ScrapeConfig, &cfg.TargetConfig, cfg.Global.FileWatch)
tms, err := targets.NewTargetManagers(p, p.reg, p.logger, cfg.PositionsConfig, p.entriesFanout, cfg.ScrapeConfig, &cfg.TargetConfig, cfg.Global.FileWatch, &cfg.LimitsConfig)
if err != nil {
return err
}

@@ -1,10 +1,8 @@
package docker
import (
"bufio"
"context"
"fmt"
"io"
"strconv"
"strings"
"sync"
@@ -12,7 +10,6 @@ import (
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/stdcopy"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/prometheus/common/model"
@@ -24,6 +21,7 @@ import (
"github.com/grafana/loki/v3/clients/pkg/promtail/positions"
"github.com/grafana/loki/v3/clients/pkg/promtail/targets/target"
"github.com/grafana/loki/v3/pkg/framedstdcopy"
"github.com/grafana/loki/v3/pkg/logproto"
)
@@ -36,6 +34,7 @@ type Target struct {
labels model.LabelSet
relabelConfig []*relabel.Config
metrics *Metrics
maxLineSize int
cancel context.CancelFunc
client client.APIClient
@@ -53,6 +52,7 @@ func NewTarget(
labels model.LabelSet,
relabelConfig []*relabel.Config,
client client.APIClient,
maxLineSize int,
) (*Target, error) {
pos, err := position.Get(positions.CursorKey(containerName))
@@ -73,6 +73,7 @@
labels: labels,
relabelConfig: relabelConfig,
metrics: metrics,
maxLineSize: maxLineSize,
client: client,
running: atomic.NewBool(false),
@@ -109,22 +110,22 @@ func (t *Target) processLoop(ctx context.Context) {
}
// Start transferring
rstdout, wstdout := io.Pipe()
rstderr, wstderr := io.Pipe()
cstdout := make(chan []byte)
cstderr := make(chan []byte)
t.wg.Add(1)
go func() {
defer func() {
t.wg.Done()
wstdout.Close()
wstderr.Close()
close(cstdout)
close(cstderr)
t.Stop()
}()
var written int64
var err error
if inspectInfo.Config.Tty {
written, err = io.Copy(wstdout, logs)
written, err = framedstdcopy.NoHeaderFramedStdCopy(cstdout, logs)
} else {
written, err = stdcopy.StdCopy(wstdout, wstderr, logs)
written, err = framedstdcopy.FramedStdCopy(cstdout, cstderr, logs)
}
if err != nil {
level.Warn(t.logger).Log("msg", "could not transfer logs", "written", written, "container", t.containerName, "err", err)
@@ -135,8 +136,8 @@ func (t *Target) processLoop(ctx context.Context) {
// Start processing
t.wg.Add(2)
go t.process(rstdout, "stdout")
go t.process(rstderr, "stderr")
go t.process(cstdout, "stdout")
go t.process(cstderr, "stderr")
// Wait until done
<-ctx.Done()
@@ -149,81 +150,120 @@ func (t *Target) processLoop(ctx context.Context) {
func extractTs(line string) (time.Time, string, error) {
pair := strings.SplitN(line, " ", 2)
if len(pair) != 2 {
return time.Now(), line, fmt.Errorf("Could not find timestamp in '%s'", line)
return time.Now(), line, fmt.Errorf("could not find timestamp in '%s'", line)
}
ts, err := time.Parse("2006-01-02T15:04:05.999999999Z07:00", pair[0])
if err != nil {
return time.Now(), line, fmt.Errorf("Could not parse timestamp from '%s': %w", pair[0], err)
return time.Now(), line, fmt.Errorf("could not parse timestamp from '%s': %w", pair[0], err)
}
return ts, pair[1], nil
}
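Since every frame payload begins with Docker's RFC3339Nano timestamp, the layout string above is exactly `time.RFC3339Nano`. A minimal standalone sketch of the same parsing (the `main` wrapper is illustrative only, not part of the commit):
```go
package main

import (
	"fmt"
	"strings"
	"time"
)

// extractTs mirrors the function above: split off the RFC3339Nano timestamp
// that Docker prefixes to each payload when timestamps are requested, and
// return it alongside the remaining line.
func extractTs(line string) (time.Time, string, error) {
	pair := strings.SplitN(line, " ", 2)
	if len(pair) != 2 {
		return time.Now(), line, fmt.Errorf("could not find timestamp in '%s'", line)
	}
	ts, err := time.Parse(time.RFC3339Nano, pair[0])
	if err != nil {
		return time.Now(), line, fmt.Errorf("could not parse timestamp from '%s': %w", pair[0], err)
	}
	return ts, pair[1], nil
}

func main() {
	ts, payload, err := extractTs("2024-03-14T15:32:05.358979323Z hello world\n")
	fmt.Printf("%s %q %v\n", ts.UTC().Format(time.RFC3339Nano), payload, err)
	// Output: 2024-03-14T15:32:05.358979323Z "hello world\n" <nil>
}
```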
// https://devmarkpro.com/working-big-files-golang
func readLine(r *bufio.Reader) (string, error) {
func (t *Target) process(frames chan []byte, logStream string) {
defer func() {
t.wg.Done()
}()
var (
isPrefix = true
err error
line, ln []byte
sizeLimit = t.maxLineSize
discardRemainingLine = false
payloadAcc strings.Builder
curTs = time.Now()
)
for isPrefix && err == nil {
line, isPrefix, err = r.ReadLine()
ln = append(ln, line...)
// If max_line_size is disabled (set to 0), we can in theory have infinite buffer growth.
// We can't guarantee that there's any bound on Docker logs, they could be an infinite stream
// without newlines for all we know. To protect promtail from OOM in that case, we introduce
// this safety limit into the Docker target, inspired by the default Loki max_line_size value:
// https://grafana.com/docs/loki/latest/configure/#limits_config
if sizeLimit == 0 {
sizeLimit = 256 * 1024
}
return string(ln), err
}
func (t *Target) process(r io.Reader, logStream string) {
defer func() {
t.wg.Done()
}()
reader := bufio.NewReader(r)
for {
line, err := readLine(reader)
for frame := range frames {
// Split frame into timestamp and payload
ts, payload, err := extractTs(string(frame))
if err != nil {
if err == io.EOF {
break
if payloadAcc.Len() == 0 {
// If we are currently accumulating a line split over multiple frames, we would still expect
// timestamps in every frame, but since we don't use those secondary ones, we don't log an error in that case.
level.Error(t.logger).Log("msg", "error reading docker log line, skipping line", "err", err)
t.metrics.dockerErrors.Inc()
continue
}
level.Error(t.logger).Log("msg", "error reading docker log line, skipping line", "err", err)
t.metrics.dockerErrors.Inc()
ts = curTs
}
ts, line, err := extractTs(line)
if err != nil {
level.Error(t.logger).Log("msg", "could not extract timestamp, skipping line", "err", err)
t.metrics.dockerErrors.Inc()
// If time has changed, we are looking at a new event (although we should have seen a newline...),
// so flush the buffer if we have one.
if ts != curTs {
discardRemainingLine = false
if payloadAcc.Len() > 0 {
t.handleOutput(logStream, curTs, payloadAcc.String())
payloadAcc.Reset()
}
}
// Check if we have the end of the event
var isEol = strings.HasSuffix(payload, "\n")
// If we are currently discarding a line (due to size limits), skip ahead, but don't skip the next
// frame if we saw the end of the line.
if discardRemainingLine {
discardRemainingLine = !isEol
continue
}
// Add all labels from the config, relabel and filter them.
lb := labels.NewBuilder(nil)
for k, v := range t.labels {
lb.Set(string(k), string(v))
// Strip newline ending if we have it
payload = strings.TrimRight(payload, "\r\n")
// Fast path: Most log lines are a single frame. If we have a full line in frame and buffer is empty,
// then don't use the buffer at all.
if payloadAcc.Len() == 0 && isEol {
t.handleOutput(logStream, ts, payload)
continue
}
lb.Set(dockerLabelLogStream, logStream)
processed, _ := relabel.Process(lb.Labels(), t.relabelConfig...)
filtered := make(model.LabelSet)
for _, lbl := range processed {
if strings.HasPrefix(lbl.Name, "__") {
continue
}
filtered[model.LabelName(lbl.Name)] = model.LabelValue(lbl.Value)
// Add to buffer
payloadAcc.WriteString(payload)
curTs = ts
// Send immediately if line ended or we built a very large event
if isEol || payloadAcc.Len() > sizeLimit {
discardRemainingLine = !isEol
t.handleOutput(logStream, curTs, payloadAcc.String())
payloadAcc.Reset()
}
}
}
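Stripped of timestamps, labels, and channels, the reassembly rule implemented above can be sketched as a pure function, which may make the flush/discard conditions easier to follow (a sketch for illustration, not the promtail code itself):
```go
package main

import (
	"fmt"
	"strings"
)

// reassemble concatenates frame payloads until one ends in a newline, then
// emits the joined line. Once the accumulated line exceeds sizeLimit it is
// emitted early and the remainder of that line is discarded.
func reassemble(frames []string, sizeLimit int) []string {
	var (
		out     []string
		acc     strings.Builder
		discard bool
	)
	for _, payload := range frames {
		isEol := strings.HasSuffix(payload, "\n")
		if discard {
			// Still inside an over-long line: skip until its newline arrives.
			discard = !isEol
			continue
		}
		acc.WriteString(strings.TrimRight(payload, "\r\n"))
		if isEol || acc.Len() > sizeLimit {
			discard = !isEol
			out = append(out, acc.String())
			acc.Reset()
		}
	}
	return out
}

func main() {
	frames := []string{"aaa", "bbb\n", "ccc\n"}
	fmt.Println(reassemble(frames, 256*1024)) // [aaabbb ccc]
}
```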
func (t *Target) handleOutput(logStream string, ts time.Time, payload string) {
// Add all labels from the config, relabel and filter them.
lb := labels.NewBuilder(nil)
for k, v := range t.labels {
lb.Set(string(k), string(v))
}
lb.Set(dockerLabelLogStream, logStream)
processed, _ := relabel.Process(lb.Labels(), t.relabelConfig...)
t.handler.Chan() <- api.Entry{
Labels: filtered,
Entry: logproto.Entry{
Timestamp: ts,
Line: line,
},
filtered := make(model.LabelSet)
for _, lbl := range processed {
if strings.HasPrefix(lbl.Name, "__") {
continue
}
t.metrics.dockerEntries.Inc()
t.positions.Put(positions.CursorKey(t.containerName), ts.Unix())
t.since = ts.Unix()
filtered[model.LabelName(lbl.Name)] = model.LabelValue(lbl.Value)
}
t.handler.Chan() <- api.Entry{
Labels: filtered,
Entry: logproto.Entry{
Timestamp: ts,
Line: payload,
},
}
t.metrics.dockerEntries.Inc()
t.positions.Put(positions.CursorKey(t.containerName), ts.Unix())
t.since = ts.Unix()
}
// startIfNotRunning starts processing container logs. The operation is idempotent, i.e. the processing cannot be started twice.

@@ -36,6 +36,7 @@ type targetGroup struct {
httpClientConfig config.HTTPClientConfig
client client.APIClient
refreshInterval model.Duration
maxLineSize int
mtx sync.Mutex
targets map[string]*Target
@@ -120,6 +121,7 @@ func (tg *targetGroup) addTarget(id string, discoveredLabels model.LabelSet) err
discoveredLabels.Merge(tg.defaultLabels),
tg.relabelConfig,
tg.client,
tg.maxLineSize,
)
if err != nil {
return err

@@ -23,16 +23,23 @@ import (
"github.com/grafana/loki/v3/clients/pkg/promtail/positions"
)
func Test_DockerTarget(t *testing.T) {
h := func(w http.ResponseWriter, r *http.Request) {
type urlContainToPath struct {
contains string
filePath string
}
func handlerForPath(t *testing.T, paths []urlContainToPath, tty bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch path := r.URL.Path; {
case strings.HasSuffix(path, "/logs"):
var filePath string
if strings.Contains(r.URL.RawQuery, "since=0") {
filePath = "testdata/flog.log"
} else {
filePath = "testdata/flog_after_restart.log"
for _, cf := range paths {
if strings.Contains(r.URL.RawQuery, cf.contains) {
filePath = cf.filePath
break
}
}
assert.NotEmpty(t, filePath, "Did not find appropriate filePath to serve request")
dat, err := os.ReadFile(filePath)
require.NoError(t, err)
_, err = w.Write(dat)
@@ -42,15 +49,19 @@ func Test_DockerTarget(t *testing.T) {
info := types.ContainerJSON{
ContainerJSONBase: &types.ContainerJSONBase{},
Mounts: []types.MountPoint{},
Config: &container.Config{Tty: false},
Config: &container.Config{Tty: tty},
NetworkSettings: &types.NetworkSettings{},
}
err := json.NewEncoder(w).Encode(info)
require.NoError(t, err)
}
}
})
}
ts := httptest.NewServer(http.HandlerFunc(h))
func Test_DockerTarget(t *testing.T) {
h := handlerForPath(t, []urlContainToPath{{"since=0", "testdata/flog.log"}, {"", "testdata/flog_after_restart.log"}}, false)
ts := httptest.NewServer(h)
defer ts.Close()
w := log.NewSyncWriter(os.Stderr)
@@ -74,6 +85,7 @@ func Test_DockerTarget(t *testing.T) {
model.LabelSet{"job": "docker"},
[]*relabel.Config{},
client,
0,
)
require.NoError(t, err)
@@ -105,6 +117,59 @@ func Test_DockerTarget(t *testing.T) {
}, 5*time.Second, 100*time.Millisecond, "Expected log lines after restart were not found within the time limit.")
}
func doTestPartial(t *testing.T, tty bool) {
var filePath string
if tty {
filePath = "testdata/partial-tty.log"
} else {
filePath = "testdata/partial.log"
}
h := handlerForPath(t, []urlContainToPath{{"", filePath}}, tty)
ts := httptest.NewServer(h)
defer ts.Close()
w := log.NewSyncWriter(os.Stderr)
logger := log.NewLogfmtLogger(w)
entryHandler := fake.New(func() {})
client, err := client.NewClientWithOpts(client.WithHost(ts.URL))
require.NoError(t, err)
ps, err := positions.New(logger, positions.Config{
SyncPeriod: 10 * time.Second,
PositionsFile: t.TempDir() + "/positions.yml",
})
require.NoError(t, err)
target, err := NewTarget(
NewMetrics(prometheus.NewRegistry()),
logger,
entryHandler,
ps,
"flog",
model.LabelSet{"job": "docker"},
[]*relabel.Config{},
client,
0,
)
require.NoError(t, err)
expectedLines := []string{strings.Repeat("a", 16385)}
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assertExpectedLog(c, entryHandler, expectedLines)
}, 10*time.Second, 100*time.Millisecond, "Expected log lines were not found within the time limit.")
target.Stop()
entryHandler.Clear()
}
func Test_DockerTargetPartial(t *testing.T) {
doTestPartial(t, false)
}
func Test_DockerTargetPartialTty(t *testing.T) {
doTestPartial(t, true)
}
// assertExpectedLog will verify that all expectedLines were received, in any order, without duplicates.
func assertExpectedLog(c *assert.CollectT, entryHandler *fake.Client, expectedLines []string) {
logLines := entryHandler.Received()

@@ -43,6 +43,7 @@ func NewTargetManager(
positions positions.Positions,
pushClient api.EntryHandler,
scrapeConfigs []scrapeconfig.Config,
maxLineSize int,
) (*TargetManager, error) {
noopRegistry := util.NoopRegistry{}
noopSdMetrics, err := discovery.CreateAndRegisterSDMetrics(noopRegistry)
@@ -94,6 +95,7 @@ func NewTargetManager(
host: sdConfig.Host,
httpClientConfig: sdConfig.HTTPClientConfig,
refreshInterval: sdConfig.RefreshInterval,
maxLineSize: maxLineSize,
}
}
configs[syncerKey] = append(configs[syncerKey], sdConfig)

@@ -95,6 +95,7 @@ func Test_TargetManager(t *testing.T) {
ps,
entryHandler,
cfgs,
0,
)
require.NoError(t, err)
require.True(t, ta.Ready())

File diff suppressed because one or more lines are too long

@@ -9,6 +9,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/grafana/loki/v3/clients/pkg/promtail/api"
"github.com/grafana/loki/v3/clients/pkg/promtail/limit"
"github.com/grafana/loki/v3/clients/pkg/promtail/positions"
"github.com/grafana/loki/v3/clients/pkg/promtail/scrapeconfig"
"github.com/grafana/loki/v3/clients/pkg/promtail/targets/azureeventhubs"
@@ -76,6 +77,7 @@ func NewTargetManagers(
scrapeConfigs []scrapeconfig.Config,
targetConfig *file.Config,
watchConfig file.WatchConfig,
limitsConfig *limit.Config,
) (*TargetManagers, error) {
if targetConfig.Stdin {
level.Debug(logger).Log("msg", "configured to read from stdin")
@@ -273,7 +275,7 @@ func NewTargetManagers(
if err != nil {
return nil, err
}
cfTargetManager, err := docker.NewTargetManager(dockerMetrics, logger, pos, client, scrapeConfigs)
cfTargetManager, err := docker.NewTargetManager(dockerMetrics, logger, pos, client, scrapeConfigs, limitsConfig.MaxLineSize.Val())
if err != nil {
return nil, errors.Wrap(err, "failed to make Docker service discovery target manager")
}

@@ -1959,6 +1959,13 @@ or [journald](https://docs.docker.com/config/containers/logging/journald/) loggi
Note that the discovery will not pick up finished containers. That means
Promtail will not scrape the remaining logs from finished containers after a restart.
The Docker target correctly joins log lines that Docker has split into multiple frames.
To avoid unbounded line growth and out-of-memory errors in Promtail, the target applies
a default soft line size limit of 256 KiB, matching the default maximum line size in Loki.
If the accumulated line grows beyond this size, it is sent to the output immediately and the rest
of the line is discarded. To change this behaviour, set `limits_config.max_line_size` to a non-zero value,
which applies a hard limit instead.
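For example, a minimal illustrative snippet (the `512kb` value is arbitrary, not a recommendation):
```yaml
limits_config:
  # Hard per-line limit applied by the Docker target in place of the
  # 256 KiB soft default; the remainder of an over-long line is discarded.
  max_line_size: 512kb
```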
The configuration is inherited from [Prometheus' Docker service discovery](https://prometheus.io/docs/prometheus/latest/configuration/configuration/#docker_sd_config).
```yaml
@@ -2084,6 +2091,8 @@ The optional `limits_config` block configures global limits for this instance of
[max_streams: <int> | default = 0]
# Maximum log line byte size allowed without dropping. Example: 256kb, 2M. 0 to disable.
# If disabled, targets may apply default buffer size safety limits. If a target implements
# a default limit, this will be documented under the `scrape_configs` entry.
[max_line_size: <int> | default = 0]
# Whether to truncate lines that exceed max_line_size. No effect if max_line_size is disabled
[max_line_size_truncate: <bool> | default = false]

@@ -0,0 +1,173 @@
package framedstdcopy
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"github.com/docker/docker/pkg/stdcopy"
)
const (
// From stdcopy
stdWriterPrefixLen = 8
stdWriterFdIndex = 0
stdWriterSizeIndex = 4
startingBufLen = 32*1024 + stdWriterPrefixLen + 1
maxFrameLen = 16384 + 31 // In practice (undocumented) frame payload can be timestamp + 16k
)
// FramedStdCopy is a modified version of stdcopy.StdCopy.
// FramedStdCopy will demultiplex `src` in the same manner as StdCopy, but instead of
// using io.Writer for outputs, channels are used, since each frame payload may contain
// its own inner header (notably, timestamps). Frame payloads are not further parsed here,
// but are passed raw as individual slices through the output channel.
//
// FramedStdCopy will read until it hits EOF on `src`. It will then return a nil error.
// In other words: if `err` is non-nil, it indicates a real underlying error.
//
// `written` will hold the total number of bytes written to `dstout` and `dsterr`.
func FramedStdCopy(dstout, dsterr chan []byte, src io.Reader) (written int64, err error) {
var (
buf = make([]byte, startingBufLen)
bufLen = len(buf)
nr int
er error
out chan []byte
frameSize int
)
for {
// Make sure we have at least a full header
for nr < stdWriterPrefixLen {
var nr2 int
nr2, er = src.Read(buf[nr:])
nr += nr2
if er == io.EOF {
if nr < stdWriterPrefixLen {
return written, nil
}
break
}
if er != nil {
return 0, er
}
}
stream := stdcopy.StdType(buf[stdWriterFdIndex])
// Check the first byte to know where to write
switch stream {
case stdcopy.Stdin:
fallthrough
case stdcopy.Stdout:
// Write on stdout
out = dstout
case stdcopy.Stderr:
// Write on stderr
out = dsterr
case stdcopy.Systemerr:
// If we're on Systemerr, we won't write anywhere.
// NB: if this code changes later, make sure you don't try to write
// to outstream if Systemerr is the stream
out = nil
default:
return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
}
// Retrieve the size of the frame
frameSize = int(binary.BigEndian.Uint32(buf[stdWriterSizeIndex : stdWriterSizeIndex+4]))
// Check if the buffer is big enough to read the frame.
// Extend it if necessary.
if frameSize+stdWriterPrefixLen > bufLen {
buf = append(buf, make([]byte, frameSize+stdWriterPrefixLen-bufLen+1)...)
bufLen = len(buf)
}
// While the amount of bytes read is less than the size of the frame + header, we keep reading
for nr < frameSize+stdWriterPrefixLen {
var nr2 int
nr2, er = src.Read(buf[nr:])
nr += nr2
if er == io.EOF {
if nr < frameSize+stdWriterPrefixLen {
return written, nil
}
break
}
if er != nil {
return 0, er
}
}
// we might have an error from the source mixed up in our multiplexed
// stream. if we do, return it.
if stream == stdcopy.Systemerr {
return written, fmt.Errorf("error from daemon in stream: %s", string(buf[stdWriterPrefixLen:frameSize+stdWriterPrefixLen]))
}
// Write the retrieved frame (without header)
var newBuf = make([]byte, frameSize)
copy(newBuf, buf[stdWriterPrefixLen:])
out <- newBuf
written += int64(frameSize)
// Move the rest of the buffer to the beginning
copy(buf, buf[frameSize+stdWriterPrefixLen:nr])
// Move the index
nr -= frameSize + stdWriterPrefixLen
}
}
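To make the intended usage concrete: each frame in Docker's multiplexed stream carries an 8-byte header (stream-type byte, three padding bytes, 4-byte big-endian payload length), and `stdcopy.NewStdWriter` produces exactly that framing. A hedged end-to-end sketch, assuming the import path from this PR's layout:
```go
package main

import (
	"bytes"
	"fmt"
	"sync"

	"github.com/docker/docker/pkg/stdcopy"

	"github.com/grafana/loki/v3/pkg/framedstdcopy"
)

func main() {
	// Build a multiplexed stream the way the Docker daemon does:
	// every Write emits one header-prefixed frame.
	var mux bytes.Buffer
	stdcopy.NewStdWriter(&mux, stdcopy.Stdout).Write([]byte("out frame\n"))
	stdcopy.NewStdWriter(&mux, stdcopy.Stderr).Write([]byte("err frame\n"))

	stdout := make(chan []byte)
	stderr := make(chan []byte)

	var wg sync.WaitGroup
	drain := func(name string, ch chan []byte) {
		defer wg.Done()
		for frame := range ch {
			fmt.Printf("%s: %q\n", name, frame)
		}
	}
	wg.Add(2)
	go drain("stdout", stdout)
	go drain("stderr", stderr)

	// Demultiplex; frames arrive on the channels with headers stripped.
	written, err := framedstdcopy.FramedStdCopy(stdout, stderr, &mux)
	close(stdout)
	close(stderr)
	wg.Wait()
	fmt.Println(written, err) // 20 <nil>
}
```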
// NoHeaderFramedStdCopy is a specialized version of FramedStdCopy for frames without headers.
// This is the case for output from a container that has TTY set.
// In theory that makes it impossible to find the frame boundaries. This would not matter if
// timestamps were not requested, but when they are, each frame still starts with one, possibly mid-line.
// In practice most boundaries can be found by looking for newlines, since these start a new frame.
// Otherwise we rely on the maximum frame size Docker uses in practice.
func NoHeaderFramedStdCopy(dstout chan []byte, src io.Reader) (written int64, err error) {
var (
buf = make([]byte, 32768)
nrLine int
nr int
nr2 int
er error
)
for {
nr2, er = src.Read(buf[nr:])
if er == io.EOF && nr2 == 0 {
return written, nil
} else if er != nil {
return written, er
}
nr += nr2
// We might have read multiple frames, output all those we find in the buffer
for nr > 0 {
nrLine = bytes.Index(buf[:nr], []byte("\n")) + 1
if nrLine > maxFrameLen {
// we found a newline but it's in the next frame (most likely)
nrLine = maxFrameLen
} else if nrLine < 1 {
if nr >= maxFrameLen {
nrLine = maxFrameLen
} else {
// no end of frame found and we don't have enough bytes
break
}
}
// Write the frame
var newBuf = make([]byte, nrLine)
copy(newBuf, buf)
dstout <- newBuf
written += int64(nrLine)
// Move the rest of the buffer to the beginning
copy(buf, buf[nrLine:nr])
// Move the index
nr -= nrLine
}
}
}
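And the TTY-path counterpart: with no headers present, `NoHeaderFramedStdCopy` infers frame boundaries from newlines, falling back to `maxFrameLen` for longer runs. A sketch under the same import-path assumption:
```go
package main

import (
	"fmt"
	"strings"

	"github.com/grafana/loki/v3/pkg/framedstdcopy"
)

func main() {
	// TTY output carries no stream headers; timestamps still lead each frame.
	src := strings.NewReader(
		"2024-03-14T15:32:05.1Z first\n" +
			"2024-03-14T15:32:05.2Z second\n")

	out := make(chan []byte)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for frame := range out {
			fmt.Printf("frame: %q\n", frame)
		}
	}()

	written, err := framedstdcopy.NoHeaderFramedStdCopy(out, src)
	close(out)
	<-done
	fmt.Println(written, err)
	// frame: "2024-03-14T15:32:05.1Z first\n"
	// frame: "2024-03-14T15:32:05.2Z second\n"
	// 59 <nil>
}
```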

@@ -0,0 +1,269 @@
package framedstdcopy
import (
"bytes"
"errors"
"io"
"strings"
"sync"
"testing"
"github.com/docker/docker/pkg/stdcopy"
)
const (
tsPrefix string = "2024-03-14T15:32:05.358979323Z "
unprefixedFramePayloadSize int = 16384
)
func timestamped(bytes []byte) []byte {
var ts = []byte(tsPrefix)
return append(ts, bytes...)
}
func getSrcBuffer(stdOutFrames, stdErrFrames [][]byte) (buffer *bytes.Buffer, err error) {
buffer = new(bytes.Buffer)
dstOut := stdcopy.NewStdWriter(buffer, stdcopy.Stdout)
for _, stdOutBytes := range stdOutFrames {
_, err = dstOut.Write(timestamped(stdOutBytes))
if err != nil {
return
}
}
dstErr := stdcopy.NewStdWriter(buffer, stdcopy.Stderr)
for _, stdErrBytes := range stdErrFrames {
_, err = dstErr.Write(timestamped(stdErrBytes))
if err != nil {
return
}
}
return
}
type streamChans struct {
out chan []byte
err chan []byte
outCollected [][]byte
errCollected [][]byte
wg sync.WaitGroup
}
func newChans() streamChans {
return streamChans{
out: make(chan []byte),
err: make(chan []byte),
outCollected: make([][]byte, 0),
errCollected: make([][]byte, 0),
}
}
func (crx *streamChans) collectFrames() {
crx.wg.Add(1)
outClosed := false
errClosed := false
for {
if outClosed && errClosed {
crx.wg.Done()
return
}
select {
case bytes, ok := <-crx.out:
outClosed = !ok
if bytes != nil {
crx.outCollected = append(crx.outCollected, bytes)
}
case bytes, ok := <-crx.err:
errClosed = !ok
if bytes != nil {
crx.errCollected = append(crx.errCollected, bytes)
}
}
}
}
func (crx *streamChans) close() {
close(crx.out)
close(crx.err)
}
func TestStdCopyWriteAndRead(t *testing.T) {
ostr := strings.Repeat("o", unprefixedFramePayloadSize)
estr := strings.Repeat("e", unprefixedFramePayloadSize)
buffer, err := getSrcBuffer(
[][]byte{
[]byte(ostr),
[]byte(ostr[:3] + "\n"),
},
[][]byte{
[]byte(estr),
[]byte(estr[:3] + "\n"),
},
)
if err != nil {
t.Fatal(err)
}
rx := newChans()
go rx.collectFrames()
written, err := FramedStdCopy(rx.out, rx.err, buffer)
rx.close()
rx.wg.Wait()
if err != nil {
t.Fatal(err)
}
tslen := len(tsPrefix)
expectedTotalWritten := 2*maxFrameLen + 2*(4+tslen)
if written != int64(expectedTotalWritten) {
t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
}
if !bytes.Equal(rx.outCollected[0][tslen:maxFrameLen], []byte(ostr)) {
t.Fatal("Expected the first out frame to be all 'o'")
}
if !bytes.Equal(rx.outCollected[1][tslen:tslen+4], []byte("ooo\n")) {
t.Fatal("Expected the second out frame to be 'ooo\\n'")
}
if !bytes.Equal(rx.errCollected[0][tslen:maxFrameLen], []byte(estr)) {
t.Fatal("Expected the first err frame to be all 'e'")
}
if !bytes.Equal(rx.errCollected[1][tslen:tslen+4], []byte("eee\n")) {
t.Fatal("Expected the second err frame to be 'eee\\n'")
}
}
type customReader struct {
n int
err error
totalCalls int
correctCalls int
src *bytes.Buffer
}
func (f *customReader) Read(buf []byte) (int, error) {
f.totalCalls++
if f.totalCalls <= f.correctCalls {
return f.src.Read(buf)
}
return f.n, f.err
}
func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
expectedError := errors.New("error")
reader := &customReader{
err: expectedError,
}
discard := newChans()
go discard.collectFrames()
written, err := FramedStdCopy(discard.out, discard.err, reader)
discard.close()
if written != 0 {
t.Fatalf("Expected 0 bytes read, got %d", written)
}
if err != expectedError {
t.Fatalf("Didn't get expected error")
}
}
func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
expectedError := errors.New("error")
stdOutBytes := []byte(strings.Repeat("o", unprefixedFramePayloadSize))
stdErrBytes := []byte(strings.Repeat("e", unprefixedFramePayloadSize))
buffer, err := getSrcBuffer([][]byte{stdOutBytes}, [][]byte{stdErrBytes})
if err != nil {
t.Fatal(err)
}
reader := &customReader{
correctCalls: 1,
n: stdWriterPrefixLen + 1,
err: expectedError,
src: buffer,
}
discard := newChans()
go discard.collectFrames()
written, err := FramedStdCopy(discard.out, discard.err, reader)
discard.close()
if written != 0 {
t.Fatalf("Expected 0 bytes read, got %d", written)
}
if err != expectedError {
t.Fatalf("Didn't get expected error")
}
}
func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
stdOutBytes := []byte(strings.Repeat("o", unprefixedFramePayloadSize))
stdErrBytes := []byte(strings.Repeat("e", unprefixedFramePayloadSize))
buffer, err := getSrcBuffer([][]byte{stdOutBytes}, [][]byte{stdErrBytes})
if err != nil {
t.Fatal(err)
}
reader := &customReader{
correctCalls: 1,
n: stdWriterPrefixLen + 1,
err: io.EOF,
src: buffer,
}
discard := newChans()
go discard.collectFrames()
written, err := FramedStdCopy(discard.out, discard.err, reader)
discard.close()
if written != maxFrameLen {
t.Fatalf("Expected %d bytes read, got %d", 0, written)
}
if err != nil {
t.Fatal("Didn't get nil error")
}
}
func TestStdCopyWithInvalidInputHeader(t *testing.T) {
dst := newChans()
go dst.collectFrames()
src := strings.NewReader("Invalid input")
_, err := FramedStdCopy(dst.out, dst.err, src)
dst.close()
if err == nil {
t.Fatal("FramedStdCopy with invalid input header should fail.")
}
}
func TestStdCopyWithCorruptedPrefix(t *testing.T) {
data := []byte{0x01, 0x02, 0x03}
src := bytes.NewReader(data)
written, err := FramedStdCopy(nil, nil, src)
if err != nil {
t.Fatalf("FramedStdCopy should not return an error with corrupted prefix.")
}
if written != 0 {
t.Fatalf("FramedStdCopy should have written 0, but has written %d", written)
}
}
// TestStdCopyReturnsErrorFromSystem tests that FramedStdCopy correctly returns an
// error, when that error is muxed into the Systemerr stream.
func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
// write in the basic messages, just so there's some fluff in there
stdOutBytes := []byte(strings.Repeat("o", unprefixedFramePayloadSize))
stdErrBytes := []byte(strings.Repeat("e", unprefixedFramePayloadSize))
buffer, err := getSrcBuffer([][]byte{stdOutBytes}, [][]byte{stdErrBytes})
if err != nil {
t.Fatal(err)
}
// add in an error message on the Systemerr stream
systemErrBytes := []byte(strings.Repeat("S", unprefixedFramePayloadSize))
systemWriter := stdcopy.NewStdWriter(buffer, stdcopy.Systemerr)
_, err = systemWriter.Write(systemErrBytes)
if err != nil {
t.Fatal(err)
}
// now copy and demux. we should expect an error containing the string we
// wrote out
discard := newChans()
go discard.collectFrames()
_, err = FramedStdCopy(discard.out, discard.err, buffer)
discard.close()
if err == nil {
t.Fatal("expected error, got none")
}
if !strings.Contains(err.Error(), string(systemErrBytes)) {
t.Fatal("expected error to contain message")
}
}