mirror of https://github.com/grafana/loki
Improve logql query statistics collection. (#1573)
* Improve logql query statistics collection. This also adds information about ingester queries. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Improve documentation. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Fixes bad copy/paste in the result log. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Fixes ingester tests. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Improve headchunk efficiency. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Fix bad commit on master. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Fixes new interface of ingester grpc server. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
* Improve documentation of fields. Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

pull/1582/head
parent
887f0cce7b
commit
f1b8d4d8ad
@@ -1,42 +0,0 @@
package decompression

import (
	"context"
	"time"
)

type ctxKeyType string

const ctxKey ctxKeyType = "decompression"

// Stats is decompression statistic
type Stats struct {
	BytesDecompressed int64         // Total bytes decompressed data size
	BytesCompressed   int64         // Total bytes compressed read
	FetchedChunks     int64         // Total number of chunks fetched.
	TotalDuplicates   int64         // Total number of line duplicates from replication.
	TimeFetching      time.Duration // Time spent fetching chunks.
}

// NewContext creates a new decompression context
func NewContext(ctx context.Context) context.Context {
	return context.WithValue(ctx, ctxKey, &Stats{})
}

// GetStats returns decompression statistics from a context.
func GetStats(ctx context.Context) Stats {
	d, ok := ctx.Value(ctxKey).(*Stats)
	if !ok {
		return Stats{}
	}
	return *d
}

// Mutate mutates the current context statistic using a mutator function
func Mutate(ctx context.Context, mutator func(m *Stats)) {
	d, ok := ctx.Value(ctxKey).(*Stats)
	if !ok {
		return
	}
	mutator(d)
}
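
The removed decompression package above tracked chunk I/O through a single Stats value mutated via a callback; the new stats package introduced below splits this into ChunkData, IngesterData, and StoreData that call sites mutate directly. A minimal migration sketch, assuming a hypothetical chunk-iterator call site and the pkg/logql/stats import path (neither the call site nor the path is shown in this diff):

package chunkenc // hypothetical call site, for illustration only

import (
	"context"

	"github.com/grafana/loki/pkg/logql/stats" // assumed import path for the new package
)

// recordDecompressed sketches how a chunk iterator might report its work with
// the new stats package; the old decompression.Mutate callback is no longer needed.
func recordDecompressed(ctx context.Context, compressed, decompressed []byte, lines int64) {
	chunkStats := stats.GetChunkData(ctx)
	chunkStats.BytesCompressed += int64(len(compressed))
	chunkStats.BytesDecompressed += int64(len(decompressed))
	chunkStats.LinesDecompressed += lines
}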
@@ -0,0 +1,193 @@
/*
Package stats provides primitives for recording metrics across the query path.
Statistics are passed through the query context.
To start a new query statistics context use:

	ctx := stats.NewContext(ctx)

Then you can update statistics by mutating the data obtained from:

	stats.GetChunkData(ctx)
	stats.GetIngesterData(ctx)
	stats.GetStoreData(ctx)

Finally, to get a snapshot of the current query statistics, use:

	stats.Snapshot(ctx, time.Since(start))

Ingester statistics are sent across the gRPC stream using trailers;
see https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md
*/
package stats

import (
	"context"
	"time"

	"github.com/dustin/go-humanize"
	"github.com/go-kit/kit/log"
	"github.com/go-kit/kit/log/level"
)

type ctxKeyType string

const (
	trailersKey ctxKeyType = "trailers"
	chunksKey   ctxKeyType = "chunks"
	ingesterKey ctxKeyType = "ingester"
	storeKey    ctxKeyType = "store"
)

// Result contains LogQL query statistics.
type Result struct {
	Ingester Ingester
	Store    Store
	Summary  Summary
}

// Log logs a query statistics result.
func Log(log log.Logger, r Result) {
	level.Debug(log).Log(
		"Ingester.TotalReached", r.Ingester.TotalReached,
		"Ingester.TotalChunksMatched", r.Ingester.TotalChunksMatched,
		"Ingester.TotalBatches", r.Ingester.TotalBatches,
		"Ingester.TotalLinesSent", r.Ingester.TotalLinesSent,

		"Ingester.BytesUncompressed", humanize.Bytes(uint64(r.Ingester.BytesUncompressed)),
		"Ingester.LinesUncompressed", r.Ingester.LinesUncompressed,
		"Ingester.BytesDecompressed", humanize.Bytes(uint64(r.Ingester.BytesDecompressed)),
		"Ingester.LinesDecompressed", r.Ingester.LinesDecompressed,
		"Ingester.BytesCompressed", humanize.Bytes(uint64(r.Ingester.BytesCompressed)),
		"Ingester.TotalDuplicates", r.Ingester.TotalDuplicates,

		"Store.TotalChunksRef", r.Store.TotalChunksRef,
		"Store.TotalDownloadedChunks", r.Store.TotalDownloadedChunks,
		"Store.TimeDownloadingChunks", r.Store.TimeDownloadingChunks,

		"Store.BytesUncompressed", humanize.Bytes(uint64(r.Store.BytesUncompressed)),
		"Store.LinesUncompressed", r.Store.LinesUncompressed,
		"Store.BytesDecompressed", humanize.Bytes(uint64(r.Store.BytesDecompressed)),
		"Store.LinesDecompressed", r.Store.LinesDecompressed,
		"Store.BytesCompressed", humanize.Bytes(uint64(r.Store.BytesCompressed)),
		"Store.TotalDuplicates", r.Store.TotalDuplicates,

		"Summary.BytesProcessedPerSeconds", humanize.Bytes(uint64(r.Summary.BytesProcessedPerSeconds)),
		"Summary.LinesProcessedPerSeconds", r.Summary.LinesProcessedPerSeconds,
		"Summary.TotalBytesProcessed", humanize.Bytes(uint64(r.Summary.TotalBytesProcessed)),
		"Summary.TotalLinesProcessed", r.Summary.TotalLinesProcessed,
		"Summary.ExecTime", r.Summary.ExecTime,
	)
}

// Summary is the summary of a query's statistics.
type Summary struct {
	BytesProcessedPerSeconds int64         // Total bytes processed per second.
	LinesProcessedPerSeconds int64         // Total lines processed per second.
	TotalBytesProcessed      int64         // Total bytes processed.
	TotalLinesProcessed      int64         // Total lines processed.
	ExecTime                 time.Duration // Execution time.
}

// Ingester is the statistics result for ingester queries.
type Ingester struct {
	IngesterData
	ChunkData
	TotalReached int
}

// Store is the statistics result of the store.
type Store struct {
	StoreData
	ChunkData
}

// NewContext creates a new statistics context.
func NewContext(ctx context.Context) context.Context {
	ctx = injectTrailerCollector(ctx)
	ctx = context.WithValue(ctx, storeKey, &StoreData{})
	ctx = context.WithValue(ctx, chunksKey, &ChunkData{})
	ctx = context.WithValue(ctx, ingesterKey, &IngesterData{})
	return ctx
}

// ChunkData contains chunk-specific statistics.
type ChunkData struct {
	BytesUncompressed int64 // Total bytes processed that were already in memory (found in the head chunk).
	LinesUncompressed int64 // Total lines processed that were already in memory (found in the head chunk).
	BytesDecompressed int64 // Total bytes decompressed and processed from chunks.
	LinesDecompressed int64 // Total lines decompressed and processed from chunks.
	BytesCompressed   int64 // Total bytes of compressed chunks (blocks) processed.
	TotalDuplicates   int64 // Total duplicates found while processing.
}

// GetChunkData returns the chunk statistics data from the current context.
func GetChunkData(ctx context.Context) *ChunkData {
	res, ok := ctx.Value(chunksKey).(*ChunkData)
	if !ok {
		return &ChunkData{}
	}
	return res
}

// IngesterData contains ingester-specific statistics.
type IngesterData struct {
	TotalChunksMatched int64 // Total chunks matched by the query from ingesters.
	TotalBatches       int64 // Total batches sent by ingesters.
	TotalLinesSent     int64 // Total lines sent by ingesters.
}

// GetIngesterData returns the ingester statistics data from the current context.
func GetIngesterData(ctx context.Context) *IngesterData {
	res, ok := ctx.Value(ingesterKey).(*IngesterData)
	if !ok {
		return &IngesterData{}
	}
	return res
}

// StoreData contains store-specific statistics.
type StoreData struct {
	TotalChunksRef        int64         // Total chunk references fetched from the index.
	TotalDownloadedChunks int64         // Total number of chunks fetched.
	TimeDownloadingChunks time.Duration // Time spent fetching chunks.
}

// GetStoreData returns the store statistics data from the current context.
func GetStoreData(ctx context.Context) *StoreData {
	res, ok := ctx.Value(storeKey).(*StoreData)
	if !ok {
		return &StoreData{}
	}
	return res
}

// Snapshot computes query statistics from a context using the total exec time.
func Snapshot(ctx context.Context, execTime time.Duration) Result {
	var res Result
	// Ingester data is decoded from gRPC trailers.
	res.Ingester = decodeTrailers(ctx)
	// Collect data from the store.
	s, ok := ctx.Value(storeKey).(*StoreData)
	if ok {
		res.Store.StoreData = *s
	}
	// Collect data from chunk iteration.
	c, ok := ctx.Value(chunksKey).(*ChunkData)
	if ok {
		res.Store.ChunkData = *c
	}

	// Calculate the summary.
	res.Summary.TotalBytesProcessed = res.Store.BytesDecompressed + res.Store.BytesUncompressed +
		res.Ingester.BytesDecompressed + res.Ingester.BytesUncompressed
	res.Summary.BytesProcessedPerSeconds =
		int64(float64(res.Summary.TotalBytesProcessed) /
			execTime.Seconds())
	res.Summary.TotalLinesProcessed = res.Store.LinesDecompressed + res.Store.LinesUncompressed +
		res.Ingester.LinesDecompressed + res.Ingester.LinesUncompressed
	res.Summary.LinesProcessedPerSeconds =
		int64(float64(res.Summary.TotalLinesProcessed) /
			execTime.Seconds())
	res.Summary.ExecTime = execTime
	return res
}
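
The package doc above describes the intended lifecycle: create the statistics context once per query, let components on the query path mutate the shared data, then snapshot and log at the end. A minimal sketch of that flow, assuming the pkg/logql/stats import path and arbitrary example values (both assumptions, not part of this diff):

package main // illustrative only

import (
	"context"
	"os"
	"time"

	"github.com/go-kit/kit/log"

	"github.com/grafana/loki/pkg/logql/stats" // assumed import path
)

func main() {
	// Start a statistics context for the lifetime of one query.
	ctx := stats.NewContext(context.Background())
	start := time.Now()

	// Components on the query path mutate the shared data directly.
	stats.GetStoreData(ctx).TotalChunksRef += 12
	stats.GetChunkData(ctx).BytesDecompressed += 1 << 20

	// At the end of the query, snapshot the totals and log them.
	res := stats.Snapshot(ctx, time.Since(start))
	stats.Log(log.NewLogfmtLogger(os.Stderr), res)
}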
@@ -0,0 +1,92 @@
package stats

import (
	"context"
	"testing"
	"time"

	jsoniter "github.com/json-iterator/go"
	"github.com/stretchr/testify/require"
)

func TestSnapshot(t *testing.T) {
	ctx := NewContext(context.Background())

	GetChunkData(ctx).BytesUncompressed += 10
	GetChunkData(ctx).LinesUncompressed += 20
	GetChunkData(ctx).BytesDecompressed += 40
	GetChunkData(ctx).LinesDecompressed += 20
	GetChunkData(ctx).BytesCompressed += 30
	GetChunkData(ctx).TotalDuplicates += 10

	GetStoreData(ctx).TotalChunksRef += 50
	GetStoreData(ctx).TotalDownloadedChunks += 60
	GetStoreData(ctx).TimeDownloadingChunks += time.Second

	fakeIngesterQuery(ctx)
	fakeIngesterQuery(ctx)

	res := Snapshot(ctx, 2*time.Second)
	expected := Result{
		Ingester: Ingester{
			IngesterData: IngesterData{
				TotalChunksMatched: 200,
				TotalBatches:       50,
				TotalLinesSent:     60,
			},
			ChunkData: ChunkData{
				BytesUncompressed: 10,
				LinesUncompressed: 20,
				BytesDecompressed: 24,
				LinesDecompressed: 40,
				BytesCompressed:   60,
				TotalDuplicates:   2,
			},
			TotalReached: 2,
		},
		Store: Store{
			StoreData: StoreData{
				TotalChunksRef:        50,
				TotalDownloadedChunks: 60,
				TimeDownloadingChunks: time.Second,
			},
			ChunkData: ChunkData{
				BytesUncompressed: 10,
				LinesUncompressed: 20,
				BytesDecompressed: 40,
				LinesDecompressed: 20,
				BytesCompressed:   30,
				TotalDuplicates:   10,
			},
		},
		Summary: Summary{
			ExecTime:                 2 * time.Second,
			BytesProcessedPerSeconds: int64(42),
			LinesProcessedPerSeconds: int64(50),
			TotalBytesProcessed:      int64(84),
			TotalLinesProcessed:      int64(100),
		},
	}
	require.Equal(t, expected, res)
}

func fakeIngesterQuery(ctx context.Context) {
	d, _ := ctx.Value(trailersKey).(*trailerCollector)
	meta := d.addTrailer()

	c, _ := jsoniter.MarshalToString(ChunkData{
		BytesUncompressed: 5,
		LinesUncompressed: 10,
		BytesDecompressed: 12,
		LinesDecompressed: 20,
		BytesCompressed:   30,
		TotalDuplicates:   1,
	})
	meta.Set(chunkDataKey, c)
	i, _ := jsoniter.MarshalToString(IngesterData{
		TotalChunksMatched: 100,
		TotalBatches:       25,
		TotalLinesSent:     30,
	})
	meta.Set(ingesterDataKey, i)
}
@@ -0,0 +1,113 @@
package stats

import (
	"context"
	"sync"

	"github.com/cortexproject/cortex/pkg/util"
	"github.com/go-kit/kit/log/level"
	jsoniter "github.com/json-iterator/go"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)

const (
	ingesterDataKey = "ingester_data"
	chunkDataKey    = "chunk_data"
)

type trailerCollector struct {
	trailers []*metadata.MD
	sync.Mutex
}

func (c *trailerCollector) addTrailer() *metadata.MD {
	c.Lock()
	defer c.Unlock()
	meta := metadata.MD{}
	c.trailers = append(c.trailers, &meta)
	return &meta
}

func injectTrailerCollector(ctx context.Context) context.Context {
	return context.WithValue(ctx, trailersKey, &trailerCollector{})
}

// CollectTrailer registers a new trailer that can be collected by the engine.
func CollectTrailer(ctx context.Context) grpc.CallOption {
	d, ok := ctx.Value(trailersKey).(*trailerCollector)
	if !ok {
		return grpc.EmptyCallOption{}
	}
	return grpc.Trailer(d.addTrailer())
}

// SendAsTrailer sends the statistics from the current context as a trailer on the gRPC stream.
func SendAsTrailer(ctx context.Context, stream grpc.ServerStream) {
	trailer, err := encodeTrailer(ctx)
	if err != nil {
		level.Warn(util.Logger).Log("msg", "failed to encode trailer", "err", err)
		return
	}
	stream.SetTrailer(trailer)
}

func encodeTrailer(ctx context.Context) (metadata.MD, error) {
	meta := metadata.MD{}
	ingData, ok := ctx.Value(ingesterKey).(*IngesterData)
	if ok {
		data, err := jsoniter.MarshalToString(ingData)
		if err != nil {
			return meta, err
		}
		meta.Set(ingesterDataKey, data)
	}
	chunkData, ok := ctx.Value(chunksKey).(*ChunkData)
	if ok {
		data, err := jsoniter.MarshalToString(chunkData)
		if err != nil {
			return meta, err
		}
		meta.Set(chunkDataKey, data)
	}
	return meta, nil
}

func decodeTrailers(ctx context.Context) Ingester {
	var res Ingester
	collector, ok := ctx.Value(trailersKey).(*trailerCollector)
	if !ok {
		return res
	}
	res.TotalReached = len(collector.trailers)
	for _, meta := range collector.trailers {
		ing := decodeTrailer(meta)
		res.TotalChunksMatched += ing.TotalChunksMatched
		res.TotalBatches += ing.TotalBatches
		res.TotalLinesSent += ing.TotalLinesSent
		res.BytesUncompressed += ing.BytesUncompressed
		res.LinesUncompressed += ing.LinesUncompressed
		res.BytesDecompressed += ing.BytesDecompressed
		res.LinesDecompressed += ing.LinesDecompressed
		res.BytesCompressed += ing.BytesCompressed
		res.TotalDuplicates += ing.TotalDuplicates
	}
	return res
}

func decodeTrailer(meta *metadata.MD) Ingester {
	var res Ingester
	values := meta.Get(ingesterDataKey)
	if len(values) == 1 {
		if err := jsoniter.UnmarshalFromString(values[0], &res.IngesterData); err != nil {
			level.Warn(util.Logger).Log("msg", "could not unmarshal ingester data", "err", err)
		}
	}
	values = meta.Get(chunkDataKey)
	if len(values) == 1 {
		if err := jsoniter.UnmarshalFromString(values[0], &res.ChunkData); err != nil {
			level.Warn(util.Logger).Log("msg", "could not unmarshal chunk data", "err", err)
		}
	}
	return res
}
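
This file carries ingester statistics back to the querier as gRPC trailers: the server encodes its context data with SendAsTrailer as the handler returns, and the client registers one trailer slot per call with CollectTrailer so decodeTrailers (and ultimately Snapshot) can sum them. A sketch of both halves follows, and the test file in the next hunk exercises the same flow; the ingester type, function names, and import paths here are assumptions, only the stats and logproto identifiers come from this commit.

package example // illustrative sketch only

import (
	"context"
	"io"

	"github.com/grafana/loki/pkg/logproto"
	"github.com/grafana/loki/pkg/logql/stats" // assumed import path
)

type ingester struct{} // hypothetical server implementing logproto.QuerierServer

// Ingester side: record statistics while serving, then attach them as a gRPC
// trailer right before the handler returns.
func (i *ingester) Query(req *logproto.QueryRequest, srv logproto.Querier_QueryServer) error {
	ctx := stats.NewContext(srv.Context())
	defer stats.SendAsTrailer(ctx, srv)

	stats.GetIngesterData(ctx).TotalChunksMatched += 3
	stats.GetChunkData(ctx).BytesDecompressed += 4096
	// ... stream the matching entries to srv ...
	return nil
}

// Querier side: pass CollectTrailer as a call option so this ingester's trailer
// is captured into the query context; stats.Snapshot(ctx, ...) later folds all
// collected trailers into Result.Ingester.
func queryIngester(ctx context.Context, client logproto.QuerierClient) error {
	stream, err := client.Query(ctx, &logproto.QueryRequest{}, stats.CollectTrailer(ctx))
	if err != nil {
		return err
	}
	// Drain the stream; the trailer becomes available once the RPC finishes.
	for {
		if _, err := stream.Recv(); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
	}
}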
@@ -0,0 +1,110 @@
package stats

import (
	"context"
	"io"
	"log"
	"net"
	"testing"

	"github.com/grafana/loki/pkg/logproto"
	"github.com/stretchr/testify/require"
	"google.golang.org/grpc"
	"google.golang.org/grpc/test/bufconn"
)

const bufSize = 1024 * 1024

var lis *bufconn.Listener
var server *grpc.Server

func init() {
	lis = bufconn.Listen(bufSize)
	server = grpc.NewServer()
}

func bufDialer(context.Context, string) (net.Conn, error) {
	return lis.Dial()
}

func TestCollectTrailer(t *testing.T) {
	ctx := context.Background()
	conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure())
	if err != nil {
		t.Fatalf("Failed to dial bufnet: %v", err)
	}
	defer conn.Close()
	ing := ingesterFn(func(req *logproto.QueryRequest, s logproto.Querier_QueryServer) error {
		ingCtx := NewContext(s.Context())
		defer SendAsTrailer(ingCtx, s)
		GetIngesterData(ingCtx).TotalChunksMatched++
		GetIngesterData(ingCtx).TotalBatches = +2
		GetIngesterData(ingCtx).TotalLinesSent = +3
		GetChunkData(ingCtx).BytesUncompressed++
		GetChunkData(ingCtx).LinesUncompressed++
		GetChunkData(ingCtx).BytesDecompressed++
		GetChunkData(ingCtx).LinesDecompressed++
		GetChunkData(ingCtx).BytesCompressed++
		GetChunkData(ingCtx).TotalDuplicates++
		return nil
	})
	logproto.RegisterQuerierServer(server, ing)
	go func() {
		if err := server.Serve(lis); err != nil {
			log.Fatalf("Server exited with error: %v", err)
		}
	}()

	ingClient := logproto.NewQuerierClient(conn)

	ctx = NewContext(ctx)

	// query the ingester twice.
	clientStream, err := ingClient.Query(ctx, &logproto.QueryRequest{}, CollectTrailer(ctx))
	if err != nil {
		t.Fatal(err)
	}
	_, err = clientStream.Recv()
	if err != nil && err != io.EOF {
		t.Fatal(err)
	}
	clientStream, err = ingClient.Query(ctx, &logproto.QueryRequest{}, CollectTrailer(ctx))
	if err != nil {
		t.Fatal(err)
	}
	_, err = clientStream.Recv()
	if err != nil && err != io.EOF {
		t.Fatal(err)
	}
	err = clientStream.CloseSend()
	if err != nil {
		t.Fatal(err)
	}
	res := decodeTrailers(ctx)
	require.Equal(t, 2, res.TotalReached)
	require.Equal(t, int64(2), res.TotalChunksMatched)
	require.Equal(t, int64(4), res.TotalBatches)
	require.Equal(t, int64(6), res.TotalLinesSent)
	require.Equal(t, int64(2), res.BytesUncompressed)
	require.Equal(t, int64(2), res.LinesUncompressed)
	require.Equal(t, int64(2), res.BytesDecompressed)
	require.Equal(t, int64(2), res.LinesDecompressed)
	require.Equal(t, int64(2), res.BytesCompressed)
	require.Equal(t, int64(2), res.TotalDuplicates)
}

type ingesterFn func(*logproto.QueryRequest, logproto.Querier_QueryServer) error

func (i ingesterFn) Query(req *logproto.QueryRequest, s logproto.Querier_QueryServer) error {
	return i(req, s)
}
func (ingesterFn) Label(context.Context, *logproto.LabelRequest) (*logproto.LabelResponse, error) {
	return nil, nil
}
func (ingesterFn) Tail(*logproto.TailRequest, logproto.Querier_TailServer) error { return nil }
func (ingesterFn) Series(context.Context, *logproto.SeriesRequest) (*logproto.SeriesResponse, error) {
	return nil, nil
}
func (ingesterFn) TailersCount(context.Context, *logproto.TailersCountRequest) (*logproto.TailersCountResponse, error) {
	return nil, nil
}
@@ -0,0 +1,308 @@
/*
 *
 * Copyright 2017 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package bufconn provides a net.Conn implemented by a buffer and related
// dialing and listening functionality.
package bufconn

import (
	"fmt"
	"io"
	"net"
	"sync"
	"time"
)

// Listener implements a net.Listener that creates local, buffered net.Conns
// via its Accept and Dial method.
type Listener struct {
	mu   sync.Mutex
	sz   int
	ch   chan net.Conn
	done chan struct{}
}

// Implementation of net.Error providing timeout
type netErrorTimeout struct {
	error
}

func (e netErrorTimeout) Timeout() bool   { return true }
func (e netErrorTimeout) Temporary() bool { return false }

var errClosed = fmt.Errorf("closed")
var errTimeout net.Error = netErrorTimeout{error: fmt.Errorf("i/o timeout")}

// Listen returns a Listener that can only be contacted by its own Dialers and
// creates buffered connections between the two.
func Listen(sz int) *Listener {
	return &Listener{sz: sz, ch: make(chan net.Conn), done: make(chan struct{})}
}

// Accept blocks until Dial is called, then returns a net.Conn for the server
// half of the connection.
func (l *Listener) Accept() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, errClosed
	case c := <-l.ch:
		return c, nil
	}
}

// Close stops the listener.
func (l *Listener) Close() error {
	l.mu.Lock()
	defer l.mu.Unlock()
	select {
	case <-l.done:
		// Already closed.
		break
	default:
		close(l.done)
	}
	return nil
}

// Addr reports the address of the listener.
func (l *Listener) Addr() net.Addr { return addr{} }

// Dial creates an in-memory full-duplex network connection, unblocks Accept by
// providing it the server half of the connection, and returns the client half
// of the connection.
func (l *Listener) Dial() (net.Conn, error) {
	p1, p2 := newPipe(l.sz), newPipe(l.sz)
	select {
	case <-l.done:
		return nil, errClosed
	case l.ch <- &conn{p1, p2}:
		return &conn{p2, p1}, nil
	}
}

type pipe struct {
	mu sync.Mutex

	// buf contains the data in the pipe. It is a ring buffer of fixed capacity,
	// with r and w pointing to the offset to read and write, respectively.
	//
	// Data is read between [r, w) and written to [w, r), wrapping around the end
	// of the slice if necessary.
	//
	// The buffer is empty if r == len(buf), otherwise if r == w, it is full.
	//
	// w and r are always in the range [0, cap(buf)) and [0, len(buf)].
	buf  []byte
	w, r int

	wwait sync.Cond
	rwait sync.Cond

	// Indicate that a write/read timeout has occurred
	wtimedout bool
	rtimedout bool

	wtimer *time.Timer
	rtimer *time.Timer

	closed      bool
	writeClosed bool
}

func newPipe(sz int) *pipe {
	p := &pipe{buf: make([]byte, 0, sz)}
	p.wwait.L = &p.mu
	p.rwait.L = &p.mu

	p.wtimer = time.AfterFunc(0, func() {})
	p.rtimer = time.AfterFunc(0, func() {})
	return p
}

func (p *pipe) empty() bool {
	return p.r == len(p.buf)
}

func (p *pipe) full() bool {
	return p.r < len(p.buf) && p.r == p.w
}

func (p *pipe) Read(b []byte) (n int, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	// Block until p has data.
	for {
		if p.closed {
			return 0, io.ErrClosedPipe
		}
		if !p.empty() {
			break
		}
		if p.writeClosed {
			return 0, io.EOF
		}
		if p.rtimedout {
			return 0, errTimeout
		}

		p.rwait.Wait()
	}
	wasFull := p.full()

	n = copy(b, p.buf[p.r:len(p.buf)])
	p.r += n
	if p.r == cap(p.buf) {
		p.r = 0
		p.buf = p.buf[:p.w]
	}

	// Signal a blocked writer, if any
	if wasFull {
		p.wwait.Signal()
	}

	return n, nil
}

func (p *pipe) Write(b []byte) (n int, err error) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.closed {
		return 0, io.ErrClosedPipe
	}
	for len(b) > 0 {
		// Block until p is not full.
		for {
			if p.closed || p.writeClosed {
				return 0, io.ErrClosedPipe
			}
			if !p.full() {
				break
			}
			if p.wtimedout {
				return 0, errTimeout
			}

			p.wwait.Wait()
		}
		wasEmpty := p.empty()

		end := cap(p.buf)
		if p.w < p.r {
			end = p.r
		}
		x := copy(p.buf[p.w:end], b)
		b = b[x:]
		n += x
		p.w += x
		if p.w > len(p.buf) {
			p.buf = p.buf[:p.w]
		}
		if p.w == cap(p.buf) {
			p.w = 0
		}

		// Signal a blocked reader, if any.
		if wasEmpty {
			p.rwait.Signal()
		}
	}
	return n, nil
}

func (p *pipe) Close() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.closed = true
	// Signal all blocked readers and writers to return an error.
	p.rwait.Broadcast()
	p.wwait.Broadcast()
	return nil
}

func (p *pipe) closeWrite() error {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.writeClosed = true
	// Signal all blocked readers and writers to return an error.
	p.rwait.Broadcast()
	p.wwait.Broadcast()
	return nil
}

type conn struct {
	io.Reader
	io.Writer
}

func (c *conn) Close() error {
	err1 := c.Reader.(*pipe).Close()
	err2 := c.Writer.(*pipe).closeWrite()
	if err1 != nil {
		return err1
	}
	return err2
}

func (c *conn) SetDeadline(t time.Time) error {
	c.SetReadDeadline(t)
	c.SetWriteDeadline(t)
	return nil
}

func (c *conn) SetReadDeadline(t time.Time) error {
	p := c.Reader.(*pipe)
	p.mu.Lock()
	defer p.mu.Unlock()
	p.rtimer.Stop()
	p.rtimedout = false
	if !t.IsZero() {
		p.rtimer = time.AfterFunc(time.Until(t), func() {
			p.mu.Lock()
			defer p.mu.Unlock()
			p.rtimedout = true
			p.rwait.Broadcast()
		})
	}
	return nil
}

func (c *conn) SetWriteDeadline(t time.Time) error {
	p := c.Writer.(*pipe)
	p.mu.Lock()
	defer p.mu.Unlock()
	p.wtimer.Stop()
	p.wtimedout = false
	if !t.IsZero() {
		p.wtimer = time.AfterFunc(time.Until(t), func() {
			p.mu.Lock()
			defer p.mu.Unlock()
			p.wtimedout = true
			p.wwait.Broadcast()
		})
	}
	return nil
}

func (*conn) LocalAddr() net.Addr  { return addr{} }
func (*conn) RemoteAddr() net.Addr { return addr{} }

type addr struct{}

func (addr) Network() string { return "bufconn" }
func (addr) String() string  { return "bufconn" }
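
The vendored bufconn package above backs the gRPC test by pairing two in-memory ring-buffer pipes per connection, so Dial and Accept simply hand out the two halves without touching the network. A small standalone sketch of that behaviour, independent of gRPC; the echo logic and buffer sizes are illustrative only:

package main // standalone illustration of bufconn usage

import (
	"fmt"

	"google.golang.org/grpc/test/bufconn"
)

func main() {
	lis := bufconn.Listen(1 << 10) // 1 KiB pipe in each direction

	// Server half: accept the in-memory connection and echo one message.
	go func() {
		conn, err := lis.Accept()
		if err != nil {
			return
		}
		buf := make([]byte, 5)
		if _, err := conn.Read(buf); err == nil {
			conn.Write(buf) // echo back through the paired pipe
		}
		conn.Close()
	}()

	// Client half: Dial returns the other end of the buffered pipes.
	conn, err := lis.Dial()
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	conn.Write([]byte("hello"))
	reply := make([]byte, 5)
	n, _ := conn.Read(reply) // blocks until the echo arrives
	fmt.Println(string(reply[:n]))
}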