Recover from panic in http and grpc handlers. (#2059)

* Recover from panic in http and grpc handlers.

I don't see any good reason to crash any component during a bad request.

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

* Add alerts to the mixin for panics.

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>

* 😡 gomod

Signed-off-by: Cyril Tovena <cyril.tovena@gmail.com>
pull/2064/head^2
Cyril Tovena 6 years ago committed by GitHub
parent f5b9cffc2e
commit bce4470a5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 1
      go.mod
  2. 2
      pkg/loki/fake_auth.go
  3. 45
      pkg/loki/loki.go
  4. 13
      pkg/loki/modules.go
  5. 84
      pkg/querier/http.go
  6. 31
      pkg/util/server/error.go
  7. 42
      pkg/util/server/error_test.go
  8. 24
      pkg/util/server/middleware.go
  9. 35
      pkg/util/server/middleware_test.go
  10. 46
      pkg/util/server/recovery.go
  11. 44
      pkg/util/server/recovery_test.go
  12. 14
      production/loki-mixin/alerts.libsonnet
  13. 15
      vendor/github.com/grpc-ecosystem/go-grpc-middleware/recovery/doc.go
  14. 53
      vendor/github.com/grpc-ecosystem/go-grpc-middleware/recovery/interceptors.go
  15. 43
      vendor/github.com/grpc-ecosystem/go-grpc-middleware/recovery/options.go
  16. 1
      vendor/modules.txt

@ -26,6 +26,7 @@ require (
github.com/golang/snappy v0.0.1
github.com/gorilla/mux v1.7.1
github.com/gorilla/websocket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware v1.1.0
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.1-0.20191002090509-6af20e3a5340 // indirect
github.com/grpc-ecosystem/grpc-opentracing v0.0.0-20180507213350-8e809c8a8645
github.com/hashicorp/golang-lru v0.5.3

@ -19,7 +19,7 @@ var fakeHTTPAuthMiddleware = middleware.Func(func(next http.Handler) http.Handle
})
})
var fakeGRPCAuthUniaryMiddleware = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
var fakeGRPCAuthUnaryMiddleware = func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
ctx = user.InjectOrgID(ctx, "fake")
return handler(ctx, req)
}

@ -28,6 +28,7 @@ import (
"github.com/grafana/loki/pkg/querier"
"github.com/grafana/loki/pkg/querier/queryrange"
"github.com/grafana/loki/pkg/storage"
serverutil "github.com/grafana/loki/pkg/util/server"
"github.com/grafana/loki/pkg/util/validation"
)
@ -140,37 +141,33 @@ func New(cfg Config) (*Loki, error) {
}
func (t *Loki) setupAuthMiddleware() {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{serverutil.RecoveryGRPCUnaryInterceptor}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{serverutil.RecoveryGRPCStreamInterceptor}
if t.cfg.AuthEnabled {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{
middleware.ServerUserHeaderInterceptor,
}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{
func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
switch info.FullMethod {
// Don't check auth header on TransferChunks, as we weren't originally
// sending it and this could cause transfers to fail on update.
//
// Also don't check auth /frontend.Frontend/Process, as this handles
// queries for multiple users.
case "/logproto.Ingester/TransferChunks", "/frontend.Frontend/Process":
return handler(srv, ss)
default:
return middleware.StreamServerUserHeaderInterceptor(srv, ss, info, handler)
}
},
}
t.cfg.Server.GRPCMiddleware = append(t.cfg.Server.GRPCMiddleware, middleware.ServerUserHeaderInterceptor)
t.cfg.Server.GRPCStreamMiddleware = append(t.cfg.Server.GRPCStreamMiddleware, GRPCStreamAuthInterceptor)
t.httpAuthMiddleware = middleware.AuthenticateUser
} else {
t.cfg.Server.GRPCMiddleware = []grpc.UnaryServerInterceptor{
fakeGRPCAuthUniaryMiddleware,
}
t.cfg.Server.GRPCStreamMiddleware = []grpc.StreamServerInterceptor{
fakeGRPCAuthStreamMiddleware,
}
t.cfg.Server.GRPCMiddleware = append(t.cfg.Server.GRPCMiddleware, fakeGRPCAuthUnaryMiddleware)
t.cfg.Server.GRPCStreamMiddleware = append(t.cfg.Server.GRPCStreamMiddleware, fakeGRPCAuthStreamMiddleware)
t.httpAuthMiddleware = fakeHTTPAuthMiddleware
}
}
var GRPCStreamAuthInterceptor = func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
switch info.FullMethod {
// Don't check auth header on TransferChunks, as we weren't originally
// sending it and this could cause transfers to fail on update.
//
// Also don't check auth /frontend.Frontend/Process, as this handles
// queries for multiple users.
case "/logproto.Ingester/TransferChunks", "/frontend.Frontend/Process":
return handler(srv, ss)
default:
return middleware.StreamServerUserHeaderInterceptor(srv, ss, info, handler)
}
}
func (t *Loki) initModuleServices(target moduleName) (map[moduleName]services.Service, error) {
servicesMap := map[moduleName]services.Service{}

@ -34,6 +34,7 @@ import (
"github.com/grafana/loki/pkg/querier/queryrange"
loki_storage "github.com/grafana/loki/pkg/storage"
"github.com/grafana/loki/pkg/storage/stores/local"
serverutil "github.com/grafana/loki/pkg/util/server"
"github.com/grafana/loki/pkg/util/validation"
)
@ -146,6 +147,7 @@ func (t *Loki) initDistributor() (services.Service, error) {
}
pushHandler := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
t.httpAuthMiddleware,
).Wrap(http.HandlerFunc(t.distributor.PushHandler))
@ -167,8 +169,9 @@ func (t *Loki) initQuerier() (services.Service, error) {
}
httpMiddleware := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
t.httpAuthMiddleware,
querier.NewPrepopulateMiddleware(),
serverutil.NewPrepopulateMiddleware(),
)
t.server.HTTP.Handle("/loki/api/v1/query_range", httpMiddleware.Wrap(http.HandlerFunc(t.querier.RangeQueryHandler)))
t.server.HTTP.Handle("/loki/api/v1/query", httpMiddleware.Wrap(http.HandlerFunc(t.querier.InstantQueryHandler)))
@ -295,7 +298,13 @@ func (t *Loki) initQueryFrontend() (_ services.Service, err error) {
t.frontend.Wrap(tripperware)
frontend.RegisterFrontendServer(t.server.GRPC, t.frontend)
frontendHandler := queryrange.StatsHTTPMiddleware.Wrap(t.httpAuthMiddleware.Wrap(t.frontend.Handler()))
frontendHandler := middleware.Merge(
serverutil.RecoveryHTTPMiddleware,
queryrange.StatsHTTPMiddleware,
t.httpAuthMiddleware,
serverutil.NewPrepopulateMiddleware(),
).Wrap(t.frontend.Handler())
t.server.HTTP.Handle("/loki/api/v1/query_range", frontendHandler)
t.server.HTTP.Handle("/loki/api/v1/query", frontendHandler)
t.server.HTTP.Handle("/loki/api/v1/label", frontendHandler)

@ -11,7 +11,6 @@ import (
"github.com/prometheus/prometheus/pkg/labels"
"github.com/prometheus/prometheus/promql"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/middleware"
"github.com/weaveworks/common/user"
"github.com/grafana/loki/pkg/loghttp"
@ -19,13 +18,11 @@ import (
"github.com/grafana/loki/pkg/logql"
"github.com/grafana/loki/pkg/logql/marshal"
marshal_legacy "github.com/grafana/loki/pkg/logql/marshal/legacy"
serverutil "github.com/grafana/loki/pkg/util/server"
)
const (
wsPingPeriod = 1 * time.Second
// StatusClientClosedRequest is the status code for when a client request cancellation of an http request
StatusClientClosedRequest = 499
)
type QueryResponse struct {
@ -41,24 +38,24 @@ func (q *Querier) RangeQueryHandler(w http.ResponseWriter, r *http.Request) {
request, err := loghttp.ParseRangeQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Interval, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
@ -71,24 +68,24 @@ func (q *Querier) InstantQueryHandler(w http.ResponseWriter, r *http.Request) {
request, err := loghttp.ParseInstantQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
query := q.engine.NewInstantQuery(request.Query, request.Ts, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
if err := marshal.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
@ -101,41 +98,41 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {
request, err := loghttp.ParseRangeQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
request.Query, err = parseRegexQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
expr, err := logql.ParseExpr(request.Query)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
// short circuit metric queries
if _, ok := expr.(logql.SampleExpr); ok {
writeError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, "legacy endpoints only support %s result type", logql.ValueTypeStreams), w)
return
}
if err := q.validateEntriesLimits(ctx, request.Limit); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
query := q.engine.NewRangeQuery(request.Query, request.Start, request.End, request.Step, request.Interval, request.Direction, request.Limit)
result, err := query.Exec(ctx)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
if err := marshal_legacy.WriteQueryResponseJSON(result, w); err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
@ -144,13 +141,13 @@ func (q *Querier) LogQueryHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseLabelQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
resp, err := q.Label(r.Context(), req)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
@ -160,7 +157,7 @@ func (q *Querier) LabelHandler(w http.ResponseWriter, r *http.Request) {
err = marshal_legacy.WriteLabelResponseJSON(*resp, w)
}
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
@ -174,13 +171,13 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseTailQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
req.Query, err = parseRegexQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
@ -281,39 +278,23 @@ func (q *Querier) TailHandler(w http.ResponseWriter, r *http.Request) {
func (q *Querier) SeriesHandler(w http.ResponseWriter, r *http.Request) {
req, err := loghttp.ParseSeriesQuery(r)
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
serverutil.WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
resp, err := q.Series(r.Context(), req)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
err = marshal.WriteSeriesResponseJSON(*resp, w)
if err != nil {
writeError(err, w)
serverutil.WriteError(err, w)
return
}
}
// NewPrepopulateMiddleware creates a middleware which will parse incoming http forms.
// This is important because some endpoints can POST x-www-form-urlencoded bodies instead of GET w/ query strings.
func NewPrepopulateMiddleware() middleware.Interface {
return middleware.Func(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
writeError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
next.ServeHTTP(w, req)
})
})
}
// parseRegexQuery parses regex and query querystring from httpRequest and returns the combined LogQL query.
// This is used only to keep regexp query string support until it gets fully deprecated.
func parseRegexQuery(httpRequest *http.Request) (string, error) {
@ -329,23 +310,6 @@ func parseRegexQuery(httpRequest *http.Request) (string, error) {
return query, nil
}
func writeError(err error, w http.ResponseWriter) {
switch {
case err == context.Canceled:
http.Error(w, err.Error(), StatusClientClosedRequest)
case err == context.DeadlineExceeded:
http.Error(w, err.Error(), http.StatusGatewayTimeout)
case logql.IsParseError(err):
http.Error(w, err.Error(), http.StatusBadRequest)
default:
if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok {
http.Error(w, string(grpcErr.Body), int(grpcErr.Code))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func (q *Querier) validateEntriesLimits(ctx context.Context, limit uint32) error {
userID, err := user.ExtractOrgID(ctx)
if err != nil {

@ -0,0 +1,31 @@
package server
import (
"context"
"net/http"
"github.com/weaveworks/common/httpgrpc"
"github.com/grafana/loki/pkg/logql"
)
// StatusClientClosedRequest is the status code for when a client request cancellation of an http request
const StatusClientClosedRequest = 499
// WriteError write a go error with the correct status code.
func WriteError(err error, w http.ResponseWriter) {
switch {
case err == context.Canceled:
http.Error(w, err.Error(), StatusClientClosedRequest)
case err == context.DeadlineExceeded:
http.Error(w, err.Error(), http.StatusGatewayTimeout)
case logql.IsParseError(err):
http.Error(w, err.Error(), http.StatusBadRequest)
default:
if grpcErr, ok := httpgrpc.HTTPResponseFromError(err); ok {
http.Error(w, string(grpcErr.Body), int(grpcErr.Code))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}

@ -0,0 +1,42 @@
package server
import (
"context"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"
"github.com/grafana/loki/pkg/logql"
)
func Test_writeError(t *testing.T) {
for _, tt := range []struct {
name string
err error
msg string
expectedStatus int
}{
{"cancelled", context.Canceled, context.Canceled.Error(), StatusClientClosedRequest},
{"deadline", context.DeadlineExceeded, context.DeadlineExceeded.Error(), http.StatusGatewayTimeout},
{"parse error", logql.ParseError{}, "parse error : ", http.StatusBadRequest},
{"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest},
{"internal", errors.New("foo"), "foo", http.StatusInternalServerError},
} {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
WriteError(tt.err, rec)
require.Equal(t, tt.expectedStatus, rec.Result().StatusCode)
b, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
t.Fatal(err)
}
require.Equal(t, tt.msg, string(b[:len(b)-1]))
})
}
}

@ -0,0 +1,24 @@
package server
import (
"net/http"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/middleware"
)
// NewPrepopulateMiddleware creates a middleware which will parse incoming http forms.
// This is important because some endpoints can POST x-www-form-urlencoded bodies instead of GET w/ query strings.
func NewPrepopulateMiddleware() middleware.Interface {
return middleware.Func(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
WriteError(httpgrpc.Errorf(http.StatusBadRequest, err.Error()), w)
return
}
next.ServeHTTP(w, req)
})
})
}

@ -1,20 +1,14 @@
package querier
package server
import (
"bytes"
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/require"
"github.com/weaveworks/common/httpgrpc"
"github.com/grafana/loki/pkg/logql"
)
func TestPrepopulate(t *testing.T) {
@ -110,30 +104,3 @@ func TestPrepopulate(t *testing.T) {
})
}
}
func Test_writeError(t *testing.T) {
for _, tt := range []struct {
name string
err error
msg string
expectedStatus int
}{
{"cancelled", context.Canceled, context.Canceled.Error(), StatusClientClosedRequest},
{"deadline", context.DeadlineExceeded, context.DeadlineExceeded.Error(), http.StatusGatewayTimeout},
{"parse error", logql.ParseError{}, "parse error : ", http.StatusBadRequest},
{"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest},
{"internal", errors.New("foo"), "foo", http.StatusInternalServerError},
} {
t.Run(tt.name, func(t *testing.T) {
rec := httptest.NewRecorder()
writeError(tt.err, rec)
require.Equal(t, tt.expectedStatus, rec.Result().StatusCode)
b, err := ioutil.ReadAll(rec.Result().Body)
if err != nil {
t.Fatal(err)
}
require.Equal(t, tt.msg, string(b[:len(b)-1]))
})
}
}

@ -0,0 +1,46 @@
package server
import (
"fmt"
"net/http"
"os"
"runtime"
grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/weaveworks/common/httpgrpc"
"github.com/weaveworks/common/middleware"
)
const maxStacksize = 8 * 1024
var (
panicTotal = promauto.NewCounter(prometheus.CounterOpts{
Namespace: "loki",
Name: "panic_total",
Help: "The total number of panic triggered",
})
RecoveryHTTPMiddleware = middleware.Func(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() {
if p := recover(); p != nil {
WriteError(onPanic(p), w)
}
}()
next.ServeHTTP(w, req)
})
})
RecoveryGRPCStreamInterceptor = grpc_recovery.StreamServerInterceptor(grpc_recovery.WithRecoveryHandler(onPanic))
RecoveryGRPCUnaryInterceptor = grpc_recovery.UnaryServerInterceptor(grpc_recovery.WithRecoveryHandler(onPanic))
)
func onPanic(p interface{}) error {
stack := make([]byte, maxStacksize)
stack = stack[:runtime.Stack(stack, true)]
// keep a multiline stack
fmt.Fprintf(os.Stderr, "panic: %v\n%s", p, stack)
panicTotal.Inc()
return httpgrpc.Errorf(http.StatusInternalServerError, "error while processing request: %v", p)
}

@ -0,0 +1,44 @@
package server
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
func Test_onPanic(t *testing.T) {
rec := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodGet, "foo", nil)
if err != nil {
t.Fatal(err)
}
RecoveryHTTPMiddleware.
Wrap(http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) {
panic("foo bar")
})).
ServeHTTP(rec, req)
require.Equal(t, http.StatusInternalServerError, rec.Code)
require.Error(t, RecoveryGRPCStreamInterceptor(nil, fakeStream{}, nil, grpc.StreamHandler(func(srv interface{}, stream grpc.ServerStream) error {
panic("foo")
})))
_, err = RecoveryGRPCUnaryInterceptor(context.Background(), nil, nil, grpc.UnaryHandler(func(ctx context.Context, req interface{}) (interface{}, error) {
panic("foo")
}))
require.Error(t, err)
}
type fakeStream struct{}
func (fakeStream) SetHeader(_ metadata.MD) error { return nil }
func (fakeStream) SendHeader(_ metadata.MD) error { return nil }
func (fakeStream) SetTrailer(_ metadata.MD) {}
func (fakeStream) Context() context.Context { return context.Background() }
func (fakeStream) SendMsg(m interface{}) error { return nil }
func (fakeStream) RecvMsg(m interface{}) error { return nil }

@ -22,6 +22,20 @@
|||,
},
},
{
alert: 'LokiRequestPanics',
expr: |||
sum(increase(loki_panic_total[10m])) by (namespace, job) > 0
|||,
labels: {
severity: 'critical',
},
annotations: {
message: |||
{{ $labels.job }} is experiencing {{ printf "%.2f" $value }}% increase of panics.
|||,
},
},
{
alert: 'LokiRequestLatency',
expr: |||

@ -0,0 +1,15 @@
// Copyright 2017 David Ackroyd. All Rights Reserved.
// See LICENSE for licensing terms.
/*
`grpc_recovery` are intereceptors that recover from gRPC handler panics.
Server Side Recovery Middleware
By default a panic will be converted into a gRPC error with `code.Internal`.
Handling can be customised by providing an alternate recovery function.
Please see examples for simple examples of use.
*/
package grpc_recovery

@ -0,0 +1,53 @@
// Copyright 2017 David Ackroyd. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_recovery
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
// RecoveryHandlerFunc is a function that recovers from the panic `p` by returning an `error`.
type RecoveryHandlerFunc func(p interface{}) (err error)
// RecoveryHandlerFuncContext is a function that recovers from the panic `p` by returning an `error`.
// The context can be used to extract request scoped metadata and context values.
type RecoveryHandlerFuncContext func(ctx context.Context, p interface{}) (err error)
// UnaryServerInterceptor returns a new unary server interceptor for panic recovery.
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
o := evaluateOptions(opts)
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
defer func() {
if r := recover(); r != nil {
err = recoverFrom(ctx, r, o.recoveryHandlerFunc)
}
}()
return handler(ctx, req)
}
}
// StreamServerInterceptor returns a new streaming server interceptor for panic recovery.
func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
o := evaluateOptions(opts)
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) {
defer func() {
if r := recover(); r != nil {
err = recoverFrom(stream.Context(), r, o.recoveryHandlerFunc)
}
}()
return handler(srv, stream)
}
}
func recoverFrom(ctx context.Context, p interface{}, r RecoveryHandlerFuncContext) error {
if r == nil {
return grpc.Errorf(codes.Internal, "%s", p)
}
return r(ctx, p)
}

@ -0,0 +1,43 @@
// Copyright 2017 David Ackroyd. All Rights Reserved.
// See LICENSE for licensing terms.
package grpc_recovery
import "context"
var (
defaultOptions = &options{
recoveryHandlerFunc: nil,
}
)
type options struct {
recoveryHandlerFunc RecoveryHandlerFuncContext
}
func evaluateOptions(opts []Option) *options {
optCopy := &options{}
*optCopy = *defaultOptions
for _, o := range opts {
o(optCopy)
}
return optCopy
}
type Option func(*options)
// WithRecoveryHandler customizes the function for recovering from a panic.
func WithRecoveryHandler(f RecoveryHandlerFunc) Option {
return func(o *options) {
o.recoveryHandlerFunc = RecoveryHandlerFuncContext(func(ctx context.Context, p interface{}) error {
return f(p)
})
}
}
// WithRecoveryHandlerContext customizes the function for recovering from a panic.
func WithRecoveryHandlerContext(f RecoveryHandlerFuncContext) Option {
return func(o *options) {
o.recoveryHandlerFunc = f
}
}

@ -405,6 +405,7 @@ github.com/gorilla/mux
github.com/gorilla/websocket
# github.com/grpc-ecosystem/go-grpc-middleware v1.1.0
github.com/grpc-ecosystem/go-grpc-middleware
github.com/grpc-ecosystem/go-grpc-middleware/recovery
github.com/grpc-ecosystem/go-grpc-middleware/tags
github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing
github.com/grpc-ecosystem/go-grpc-middleware/util/metautils

Loading…
Cancel
Save