package util //nolint:revive import ( "bytes" "context" "encoding/json" "errors" "flag" "fmt" "html/template" "io" "net/http" "net/url" "strings" "github.com/go-kit/log" "github.com/go-kit/log/level" "github.com/gogo/protobuf/proto" "github.com/golang/snappy" attribute "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "gopkg.in/yaml.v2" ) const messageSizeLargerErrFmt = "%w than max (%d vs %d)" var ErrMessageSizeTooLarge = errors.New("message size too large") const ( HTTPRateLimited = "rate_limited" HTTPServerError = "server_error" HTTPErrorUnknown = "unknown" HTTPClientError = "client_error" ) // IsRequestBodyTooLarge returns true if the error is "http: request body too large". func IsRequestBodyTooLarge(err error) bool { return err != nil && strings.Contains(err.Error(), "http: request body too large") } // BasicAuth configures basic authentication for HTTP clients. type BasicAuth struct { Username string `yaml:"basic_auth_username"` Password string `yaml:"basic_auth_password"` } func (b *BasicAuth) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { f.StringVar(&b.Username, prefix+"basic-auth-username", "", "HTTP Basic authentication username. It overrides the username set in the URL (if any).") f.StringVar(&b.Password, prefix+"basic-auth-password", "", "HTTP Basic authentication password. It overrides the password set in the URL (if any).") } // IsEnabled returns false if basic authentication isn't enabled. func (b BasicAuth) IsEnabled() bool { return b.Username != "" || b.Password != "" } // HeaderAuth condigures header based authorization for HTTP clients. type HeaderAuth struct { Type string `yaml:"type,omitempty"` Credentials string `yaml:"credentials,omitempty"` CredentialsFile string `yaml:"credentials_file,omitempty"` } func (h *HeaderAuth) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { f.StringVar(&h.Type, prefix+"type", "Bearer", "HTTP Header authorization type (default: Bearer).") f.StringVar(&h.Credentials, prefix+"credentials", "", "HTTP Header authorization credentials.") f.StringVar(&h.CredentialsFile, prefix+"credentials-file", "", "HTTP Header authorization credentials file.") } // IsEnabled returns false if header authorization isn't enabled. func (h HeaderAuth) IsEnabled() bool { return h.Credentials != "" || h.CredentialsFile != "" } // WriteJSONResponse writes some JSON as a HTTP response. func WriteJSONResponse(w http.ResponseWriter, v interface{}) { w.Header().Set("Content-Type", "application/json") data, err := json.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } // We ignore errors here, because we cannot do anything about them. // Write will trigger sending Status code, so we cannot send a different status code afterwards. // Also this isn't internal error, but error communicating with client. _, _ = w.Write(data) } // WriteYAMLResponse writes some YAML as a HTTP response. func WriteYAMLResponse(w http.ResponseWriter, v interface{}) { // There is not standardised content-type for YAML, text/plain ensures the // YAML is displayed in the browser instead of offered as a download w.Header().Set("Content-Type", "text/plain; charset=utf-8") data, err := yaml.Marshal(v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } // We ignore errors here, because we cannot do anything about them. // Write will trigger sending Status code, so we cannot send a different status code afterwards. // Also this isn't internal error, but error communicating with client. _, _ = w.Write(data) } // Sends message as text/plain response with 200 status code. func WriteTextResponse(w http.ResponseWriter, message string) { w.Header().Set("Content-Type", "text/plain") // Ignore inactionable errors. _, _ = w.Write([]byte(message)) } // Sends message as text/html response with 200 status code. func WriteHTMLResponse(w http.ResponseWriter, message string) { w.Header().Set("Content-Type", "text/html") // Ignore inactionable errors. _, _ = w.Write([]byte(message)) } // RenderHTTPResponse either responds with json or a rendered html page using the passed in template // by checking the Accepts header func RenderHTTPResponse(w http.ResponseWriter, v interface{}, t *template.Template, r *http.Request) { accept := r.Header.Get("Accept") if strings.Contains(accept, "application/json") { WriteJSONResponse(w, v) return } err := t.Execute(w, v) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } // StreamWriteYAMLResponse stream writes data as http response func StreamWriteYAMLResponse(w http.ResponseWriter, iter chan interface{}, logger log.Logger) { w.Header().Set("Content-Type", "application/yaml") for v := range iter { data, err := yaml.Marshal(v) if err != nil { level.Error(logger).Log("msg", "yaml marshal failed", "err", err) continue } _, err = w.Write(data) if err != nil { level.Error(logger).Log("msg", "write http response failed", "err", err) return } } } // CompressionType for encoding and decoding requests and responses. type CompressionType int // Values for CompressionType const ( NoCompression CompressionType = iota RawSnappy ) // ParseProtoReader parses a compressed proto from an io.Reader. func ParseProtoReader(ctx context.Context, reader io.Reader, expectedSize, maxSize int, req proto.Message, compression CompressionType) error { sp := trace.SpanFromContext(ctx) sp.AddEvent("util.ParseProtoRequest[start reading]") body, err := decompressRequest(reader, expectedSize, maxSize, compression, sp) if err != nil { return err } sp.AddEvent("util.ParseProtoRequest[unmarshal]", trace.WithAttributes(attribute.Int("size", len(body)))) // We re-implement proto.Unmarshal here as it calls XXX_Unmarshal first, // which we can't override without upsetting golint. req.Reset() if u, ok := req.(proto.Unmarshaler); ok { err = u.Unmarshal(body) } else { err = proto.NewBuffer(body).Unmarshal(req) } if err != nil { return err } return nil } func decompressRequest(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) (body []byte, err error) { defer func() { if err != nil && len(body) > maxSize { err = fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(body), maxSize) } }() if expectedSize > maxSize { return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, expectedSize, maxSize) } buffer, ok := tryBufferFromReader(reader) if ok { body, err = decompressFromBuffer(buffer, maxSize, compression, sp) return } body, err = decompressFromReader(reader, expectedSize, maxSize, compression, sp) return } func decompressFromReader(reader io.Reader, expectedSize, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) { var ( buf bytes.Buffer body []byte err error ) if expectedSize > 0 { buf.Grow(expectedSize + bytes.MinRead) // extra space guarantees no reallocation } // Read from LimitReader with limit max+1. So if the underlying // reader is over limit, the result will be bigger than max. reader = io.LimitReader(reader, int64(maxSize)+1) switch compression { case NoCompression: _, err = buf.ReadFrom(reader) body = buf.Bytes() case RawSnappy: _, err = buf.ReadFrom(reader) if err != nil { return nil, err } body, err = decompressFromBuffer(&buf, maxSize, RawSnappy, sp) } return body, err } func decompressFromBuffer(buffer *bytes.Buffer, maxSize int, compression CompressionType, sp trace.Span) ([]byte, error) { bufBytes := buffer.Bytes() if len(bufBytes) > maxSize { return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, len(bufBytes), maxSize) } switch compression { case NoCompression: return bufBytes, nil case RawSnappy: sp.AddEvent("util.ParseProtoRequest[decompress]", trace.WithAttributes( attribute.Int("size", len(bufBytes)), )) size, err := snappy.DecodedLen(bufBytes) if err != nil { return nil, err } if size > maxSize { return nil, fmt.Errorf(messageSizeLargerErrFmt, ErrMessageSizeTooLarge, size, maxSize) } body, err := snappy.Decode(nil, bufBytes) if err != nil { return nil, err } return body, nil } return nil, nil } // tryBufferFromReader attempts to cast the reader to a `*bytes.Buffer` this is possible when using httpgrpc. // If it fails it will return nil and false. func tryBufferFromReader(reader io.Reader) (*bytes.Buffer, bool) { if bufReader, ok := reader.(interface { BytesBuffer() *bytes.Buffer }); ok && bufReader != nil { return bufReader.BytesBuffer(), true } return nil, false } // SerializeProtoResponse serializes a protobuf response into an HTTP response. func SerializeProtoResponse(w http.ResponseWriter, resp proto.Message, compression CompressionType) error { data, err := proto.Marshal(resp) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return fmt.Errorf("error marshaling proto response: %v", err) } switch compression { case NoCompression: case RawSnappy: data = snappy.Encode(nil, data) } if _, err := w.Write(data); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return fmt.Errorf("error sending proto response: %v", err) } return nil } func FlagFromValues(values url.Values, key string, d bool) bool { switch strings.ToLower(values.Get(key)) { case "t", "true", "1": return true case "f", "false", "0": return false default: return d } } func IsValidURL(endpoint string) bool { u, err := url.Parse(endpoint) if err != nil { return false } return u.Scheme != "" && u.Host != "" } func ErrorTypeFromHTTPStatus(status int) string { errorType := HTTPErrorUnknown if status == 429 { errorType = HTTPRateLimited } else if status/100 == 5 { errorType = HTTPServerError } else if status/100 != 2 { errorType = HTTPClientError } return errorType } func IsError(status int) bool { return status < 200 || status >= 300 } func IsServerError(status int) bool { return status/100 == 5 } func IsRateLimited(status int) bool { return status == 429 }