diff --git a/pkg/util/errors.go b/pkg/util/errors.go index 87efe4185d..70b1475f74 100644 --- a/pkg/util/errors.go +++ b/pkg/util/errors.go @@ -78,6 +78,41 @@ func (es MultiError) Is(target error) bool { return true } +// IsCancel tells if all errors are either context.Canceled or grpc codes.Canceled. +func (es MultiError) IsCancel() bool { + if len(es) == 0 { + return false + } + for _, err := range es { + if errors.Is(err, context.Canceled) { + continue + } + if IsConnCanceled(err) { + continue + } + return false + } + return true +} + +// IsDeadlineExceeded tells if all errors are either context.DeadlineExceeded or grpc codes.DeadlineExceeded. +func (es MultiError) IsDeadlineExceeded() bool { + if len(es) == 0 { + return false + } + for _, err := range es { + if errors.Is(err, context.DeadlineExceeded) { + continue + } + s, ok := status.FromError(err) + if ok && s.Code() == codes.DeadlineExceeded { + continue + } + return false + } + return true +} + // IsConnCanceled returns true, if error is from a closed gRPC connection. // copied from https://github.com/etcd-io/etcd/blob/7f47de84146bdc9225d2080ec8678ca8189a2d2b/clientv3/client.go#L646 func IsConnCanceled(err error) bool { diff --git a/pkg/util/server/error.go b/pkg/util/server/error.go index da41c229ad..05ee1bc510 100644 --- a/pkg/util/server/error.go +++ b/pkg/util/server/error.go @@ -5,6 +5,11 @@ import ( "errors" "net/http" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/grafana/loki/pkg/util" + "github.com/prometheus/prometheus/promql" "github.com/weaveworks/common/httpgrpc" "github.com/weaveworks/common/user" @@ -28,11 +33,24 @@ func WriteError(err error, w http.ResponseWriter) { promErr promql.ErrStorage ) + me, ok := err.(util.MultiError) + if ok && me.IsCancel() { + http.Error(w, ErrClientCanceled, StatusClientClosedRequest) + return + } + if ok && me.IsDeadlineExceeded() { + http.Error(w, ErrDeadlineExceeded, http.StatusGatewayTimeout) + return + } + + s, isRPC := status.FromError(err) switch { case errors.Is(err, context.Canceled) || + (isRPC && s.Code() == codes.Canceled) || (errors.As(err, &promErr) && errors.Is(promErr.Err, context.Canceled)): http.Error(w, ErrClientCanceled, StatusClientClosedRequest) - case errors.Is(err, context.DeadlineExceeded): + case errors.Is(err, context.DeadlineExceeded) || + (isRPC && s.Code() == codes.DeadlineExceeded): http.Error(w, ErrDeadlineExceeded, http.StatusGatewayTimeout) case errors.As(err, &queryErr): http.Error(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/util/server/error_test.go b/pkg/util/server/error_test.go index 948ed2fb9a..07975ab84a 100644 --- a/pkg/util/server/error_test.go +++ b/pkg/util/server/error_test.go @@ -9,6 +9,9 @@ import ( "net/http/httptest" "testing" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "github.com/prometheus/prometheus/promql" "github.com/stretchr/testify/require" "github.com/weaveworks/common/httpgrpc" @@ -29,9 +32,18 @@ func Test_writeError(t *testing.T) { }{ {"cancelled", context.Canceled, ErrClientCanceled, StatusClientClosedRequest}, {"cancelled multi", util.MultiError{context.Canceled, context.Canceled}, ErrClientCanceled, StatusClientClosedRequest}, + {"rpc cancelled", status.New(codes.Canceled, context.Canceled.Error()).Err(), ErrClientCanceled, StatusClientClosedRequest}, + {"rpc cancelled multi", util.MultiError{status.New(codes.Canceled, context.Canceled.Error()).Err(), status.New(codes.Canceled, context.Canceled.Error()).Err()}, ErrClientCanceled, StatusClientClosedRequest}, + {"mixed context and rpc cancelled", util.MultiError{context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, ErrClientCanceled, StatusClientClosedRequest}, + {"mixed context, rpc cancelled and another", util.MultiError{errors.New("standard error"), context.Canceled, status.New(codes.Canceled, context.Canceled.Error()).Err()}, "3 errors: standard error; context canceled; rpc error: code = Canceled desc = context canceled", http.StatusInternalServerError}, {"cancelled storage", promql.ErrStorage{Err: context.Canceled}, ErrClientCanceled, StatusClientClosedRequest}, {"orgid", user.ErrNoOrgID, user.ErrNoOrgID.Error(), http.StatusBadRequest}, {"deadline", context.DeadlineExceeded, ErrDeadlineExceeded, http.StatusGatewayTimeout}, + {"deadline multi", util.MultiError{context.DeadlineExceeded, context.DeadlineExceeded}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, + {"rpc deadline", status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err(), ErrDeadlineExceeded, http.StatusGatewayTimeout}, + {"rpc deadline multi", util.MultiError{status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err(), status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, + {"mixed context and rpc deadline", util.MultiError{context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, ErrDeadlineExceeded, http.StatusGatewayTimeout}, + {"mixed context, rpc deadline and another", util.MultiError{errors.New("standard error"), context.DeadlineExceeded, status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()).Err()}, "3 errors: standard error; context deadline exceeded; rpc error: code = DeadlineExceeded desc = context deadline exceeded", http.StatusInternalServerError}, {"parse error", logqlmodel.ParseError{}, "parse error : ", http.StatusBadRequest}, {"httpgrpc", httpgrpc.Errorf(http.StatusBadRequest, errors.New("foo").Error()), "foo", http.StatusBadRequest}, {"internal", errors.New("foo"), "foo", http.StatusInternalServerError},