From bf06e45bcfc7745ee26798fdfbbbe31ecfb3bd3f Mon Sep 17 00:00:00 2001 From: Joao Marcal Date: Tue, 26 May 2026 15:41:43 +0100 Subject: [PATCH] feat(operator): Add passthrough-gateway component (#20382) Signed-off-by: Joao Marcal Co-authored-by: Robert Jacob --- .github/workflows/operator-images.yaml | 9 + operator/Makefile | 10 + operator/cmd/passthrough-gateway/main.go | 60 ++++ .../internal/passthroughgateway/config.go | 126 +++++++ .../internal/passthroughgateway/metrics.go | 43 +++ .../passthroughgateway/metrics_test.go | 57 ++++ operator/internal/passthroughgateway/proxy.go | 154 +++++++++ .../internal/passthroughgateway/proxy_test.go | 117 +++++++ .../internal/passthroughgateway/server.go | 177 ++++++++++ operator/internal/passthroughgateway/tls.go | 231 +++++++++++++ .../internal/passthroughgateway/tls_test.go | 307 ++++++++++++++++++ operator/passthrough-gateway.Dockerfile | 26 ++ 12 files changed, 1317 insertions(+) create mode 100644 operator/cmd/passthrough-gateway/main.go create mode 100644 operator/internal/passthroughgateway/config.go create mode 100644 operator/internal/passthroughgateway/metrics.go create mode 100644 operator/internal/passthroughgateway/metrics_test.go create mode 100644 operator/internal/passthroughgateway/proxy.go create mode 100644 operator/internal/passthroughgateway/proxy_test.go create mode 100644 operator/internal/passthroughgateway/server.go create mode 100644 operator/internal/passthroughgateway/tls.go create mode 100644 operator/internal/passthroughgateway/tls_test.go create mode 100644 operator/passthrough-gateway.Dockerfile diff --git a/.github/workflows/operator-images.yaml b/.github/workflows/operator-images.yaml index e0440800ea..eef5673240 100644 --- a/.github/workflows/operator-images.yaml +++ b/.github/workflows/operator-images.yaml @@ -44,3 +44,12 @@ jobs: organization: "openshift-logging" image_name: "storage-size-calculator" tag: "latest" + + publish-openshift-passthrough-gateway: + uses: ./.github/workflows/operator-reusable-image-build.yml + with: + dockerfile: "operator/passthrough-gateway.Dockerfile" + registry: "quay.io" + organization: "openshift-logging" + image_name: "passthrough-gateway" + tag: "latest" diff --git a/operator/Makefile b/operator/Makefile index 10501cce85..5e802f88d2 100644 --- a/operator/Makefile +++ b/operator/Makefile @@ -85,6 +85,7 @@ ifeq ($(USE_IMAGE_DIGESTS), true) endif CALCULATOR_IMG ?= $(REGISTRY_BASE)/storage-size-calculator:latest +PASSTHROUGH_GATEWAY_IMG ?= $(REGISTRY_BASE)/passthrough-gateway:latest GO_FILES := $(shell find . -type f -name '*.go') @@ -313,6 +314,15 @@ oci-build-calculator: ## Build the calculator image oci-push-calculator: ## Push the calculator image $(OCI_RUNTIME) push $(CALCULATOR_IMG) +##@ Passthrough Gateway +.PHONY: oci-build-passthrough-gateway +oci-build-passthrough-gateway: ## Build the passthrough gateway image + $(OCI_RUNTIME) build -f passthrough-gateway.Dockerfile -t $(PASSTHROUGH_GATEWAY_IMG) . + +.PHONY: oci-push-passthrough-gateway +oci-push-passthrough-gateway: ## Push the passthrough gateway image + $(OCI_RUNTIME) push $(PASSTHROUGH_GATEWAY_IMG) + ##@ Website TYPES_TARGET := $(shell find api/loki -type f -iname "*_types.go") diff --git a/operator/cmd/passthrough-gateway/main.go b/operator/cmd/passthrough-gateway/main.go new file mode 100644 index 0000000000..6a2a1412f7 --- /dev/null +++ b/operator/cmd/passthrough-gateway/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "context" + "flag" + "os" + "os/signal" + "syscall" + + "github.com/ViaQ/logerr/v2/log" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + + "github.com/grafana/loki/operator/internal/passthroughgateway" +) + +func main() { + logger := log.NewLogger("lokistack-gateway") + + cfg := &passthroughgateway.Config{} + f := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + cfg.RegisterFlags(f) + + if err := f.Parse(os.Args[1:]); err != nil { + logger.Error(err, "failed to parse flags") + os.Exit(1) + } + + if err := cfg.Validate(); err != nil { + logger.Error(err, "invalid configuration") + os.Exit(1) + } + + logger.Info("starting gateway", + "listen-addr", cfg.ListenAddr, + "admin-addr", cfg.AdminAddr, + "loki-distributor-endpoint", cfg.Loki.DistributorEndpoint, + "loki-query-frontend-endpoint", cfg.Loki.QueryFrontendEndpoint, + ) + + reg := prometheus.NewRegistry() + reg.MustRegister(collectors.NewGoCollector()) + reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) + + server, err := passthroughgateway.NewServer(cfg, logger, reg) + if err != nil { + logger.Error(err, "failed to create server") + os.Exit(1) + } + + ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer cancel() + + if err := server.Run(ctx); err != nil { + logger.Error(err, "server error") + os.Exit(1) + } + + logger.Info("lokistack gateway stopped") +} diff --git a/operator/internal/passthroughgateway/config.go b/operator/internal/passthroughgateway/config.go new file mode 100644 index 0000000000..9b6ce523f9 --- /dev/null +++ b/operator/internal/passthroughgateway/config.go @@ -0,0 +1,126 @@ +package passthroughgateway + +import ( + "flag" + "net/url" + "strings" + "time" + + "github.com/ViaQ/logerr/v2/kverrors" +) + +type LokiConfig struct { + DistributorEndpoint string + QueryFrontendEndpoint string + Timeout time.Duration + CAFile string + CertFile string + KeyFile string +} + +type Config struct { + ListenAddr string + AdminAddr string + DefaultTenant string + Loki LokiConfig + // Server TLS configuration + TLSCertFile string + TLSKeyFile string + // Client TLS configuration + TLSClientCAFile string + TLSClientAuth string + // Generic TLS configuration + TLSMinVersion string + TLSMaxVersion string + TLSCipherSuites StringSlice + TLSCurvePrefs StringSlice +} + +// StringSlice is a custom flag type for comma-separated strings. +type StringSlice []string + +func (s *StringSlice) String() string { + return strings.Join(*s, ",") +} + +func (s *StringSlice) Set(value string) error { + if value == "" { + return nil + } + *s = strings.Split(value, ",") + return nil +} + +// TLSOptions returns the TLS configuration options. +func (c *Config) TLSOptions() *TLSConfig { + return &TLSConfig{ + MinVersion: c.TLSMinVersion, + MaxVersion: c.TLSMaxVersion, + CipherSuites: c.TLSCipherSuites, + CurvePrefs: c.TLSCurvePrefs, + ClientAuth: c.TLSClientAuth, + } +} + +// RegisterFlags registers the configuration flags. +func (c *Config) RegisterFlags(f *flag.FlagSet) { + f.StringVar(&c.ListenAddr, "listen-addr", ":8080", "Address for the server to listen on.") + f.StringVar(&c.AdminAddr, "admin-addr", ":9090", "Address for admin endpoints (metrics, health, readiness).") + f.StringVar(&c.DefaultTenant, "default-tenant", "", "Default tenant ID to use when X-Scope-OrgID header is not set. If empty, requests without X-Scope-OrgID are rejected.") + // Loki upstream configuration flags + f.StringVar(&c.Loki.DistributorEndpoint, "loki-distributor-endpoint", "", "Upstream URL of the Loki distributor (write path).") + f.StringVar(&c.Loki.QueryFrontendEndpoint, "loki-query-frontend-endpoint", "", "Upstream URL of the Loki query frontend (read path).") + f.DurationVar(&c.Loki.Timeout, "loki-timeout", 60*time.Second, "Timeout for upstream Loki requests. Set to 0 for no timeout.") + f.StringVar(&c.Loki.CAFile, "loki-ca-file", "", "Path to the CA certificate for verifying the Loki server.") + f.StringVar(&c.Loki.CertFile, "loki-cert-file", "", "Path to the client certificate for Loki mTLS.") + f.StringVar(&c.Loki.KeyFile, "loki-key-file", "", "Path to the client private key for Loki mTLS.") + // Server TLS configuration flags + f.StringVar(&c.TLSCertFile, "tls-cert-file", "", "Path to the server TLS certificate file.") + f.StringVar(&c.TLSKeyFile, "tls-key-file", "", "Path to the server TLS private key file.") + // Client TLS configuration flags + f.StringVar(&c.TLSClientCAFile, "tls-client-ca-file", "", "Path to the CA certificate for verifying client certificates.") + f.StringVar(&c.TLSClientAuth, "tls-client-auth", "NoClientCert", "Client certificate auth mode (NoClientCert, RequestClientCert, RequireAnyClientCert, VerifyClientCertIfGiven, RequireAndVerifyClientCert).") + // Generic TLS configuration flags + f.StringVar(&c.TLSMinVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13).") + f.StringVar(&c.TLSMaxVersion, "tls-max-version", "", "Maximum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13). Empty means no maximum.") + f.Var(&c.TLSCipherSuites, "tls-cipher-suites", "Comma-separated list of TLS cipher suites.") + f.Var(&c.TLSCurvePrefs, "tls-curve-preferences", "Comma-separated list of curve preferences (X25519, CurveP256, CurveP384, CurveP521).") +} + +// Validate checks that all required configuration is provided and valid. +func (c *Config) Validate() error { + if c.Loki.DistributorEndpoint == "" { + return kverrors.New("-loki-distributor-endpoint is required") + } + + u, err := url.Parse(c.Loki.DistributorEndpoint) + if err != nil { + return kverrors.New("-loki-distributor-endpoint is not a valid URL", "endpoint", c.Loki.DistributorEndpoint, "err", err) + } + if u.Host == "" { + return kverrors.New("-loki-distributor-endpoint is missing a host", "endpoint", c.Loki.DistributorEndpoint) + } + + if c.Loki.QueryFrontendEndpoint == "" { + return kverrors.New("-loki-query-frontend-endpoint is required") + } + + u, err = url.Parse(c.Loki.QueryFrontendEndpoint) + if err != nil { + return kverrors.New("-loki-query-frontend-endpoint is not a valid URL", "endpoint", c.Loki.QueryFrontendEndpoint, "err", err) + } + if u.Host == "" { + return kverrors.New("-loki-query-frontend-endpoint is missing a host", "endpoint", c.Loki.QueryFrontendEndpoint) + } + + if c.DefaultTenant != "" && strings.ContainsAny(c.DefaultTenant, "\n") { + return kverrors.New("-default-tenant must not contain newline characters") + } + + return nil +} + +// TLSEnabled returns true if TLS certificates are configured. +func (c *Config) TLSEnabled() bool { + return c.TLSCertFile != "" && c.TLSKeyFile != "" +} diff --git a/operator/internal/passthroughgateway/metrics.go b/operator/internal/passthroughgateway/metrics.go new file mode 100644 index 0000000000..a200a95346 --- /dev/null +++ b/operator/internal/passthroughgateway/metrics.go @@ -0,0 +1,43 @@ +package passthroughgateway + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// metrics holds Prometheus metrics for the LokiStack gateway. +type metrics struct { + RequestsTotal *prometheus.CounterVec + RequestDuration *prometheus.HistogramVec + RequestsInFlight *prometheus.GaugeVec +} + +// newMetrics creates and registers Prometheus metrics for the LokiStack gateway. +func newMetrics(reg prometheus.Registerer) *metrics { + factory := promauto.With(reg) + + return &metrics{ + RequestsTotal: factory.NewCounterVec( + prometheus.CounterOpts{ + Name: "lokistack_gateway_requests_total", + Help: "Total number of requests processed by the LokiStack gateway.", + }, + []string{"method", "route", "status_code"}, + ), + RequestDuration: factory.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "lokistack_gateway_request_duration_seconds", + Help: "Duration of requests processed by the LokiStack gateway.", + Buckets: []float64{.1, 1, 5, 9, 15, 30, 60, 120, 300}, + }, + []string{"method", "route"}, + ), + RequestsInFlight: factory.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "lokistack_gateway_requests_in_flight", + Help: "Current number of requests being processed by the LokiStack gateway.", + }, + []string{"route"}, + ), + } +} diff --git a/operator/internal/passthroughgateway/metrics_test.go b/operator/internal/passthroughgateway/metrics_test.go new file mode 100644 index 0000000000..71afa229b8 --- /dev/null +++ b/operator/internal/passthroughgateway/metrics_test.go @@ -0,0 +1,57 @@ +package passthroughgateway + +import ( + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" +) + +func TestMetricsRequestsTotal(t *testing.T) { + registry := prometheus.NewPedanticRegistry() + metrics := newMetrics(registry) + + metrics.RequestsTotal.WithLabelValues("GET", "read", "200").Inc() + metrics.RequestsTotal.WithLabelValues("POST", "write", "201").Inc() + + expected := ` +# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway. +# TYPE lokistack_gateway_requests_total counter +lokistack_gateway_requests_total{method="GET",route="read",status_code="200"} 1 +lokistack_gateway_requests_total{method="POST",route="write",status_code="201"} 1 +` + + err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expected)) + require.NoError(t, err) +} + +func TestMetricsRequestDuration(t *testing.T) { + registry := prometheus.NewPedanticRegistry() + metrics := newMetrics(registry) + + metrics.RequestDuration.WithLabelValues("GET", "/api/logs").Observe(0.5) + + count := testutil.CollectAndCount(metrics.RequestDuration) + require.Equal(t, 1, count) +} + +func TestMetricsRequestsInFlight(t *testing.T) { + registry := prometheus.NewPedanticRegistry() + metrics := newMetrics(registry) + + metrics.RequestsInFlight.WithLabelValues("read").Inc() + metrics.RequestsInFlight.WithLabelValues("write").Inc() + metrics.RequestsInFlight.WithLabelValues("read").Dec() + + expected := ` +# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway. +# TYPE lokistack_gateway_requests_in_flight gauge +lokistack_gateway_requests_in_flight{route="read"} 0 +lokistack_gateway_requests_in_flight{route="write"} 1 +` + + err := testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expected)) + require.NoError(t, err) +} diff --git a/operator/internal/passthroughgateway/proxy.go b/operator/internal/passthroughgateway/proxy.go new file mode 100644 index 0000000000..37551f3b6a --- /dev/null +++ b/operator/internal/passthroughgateway/proxy.go @@ -0,0 +1,154 @@ +package passthroughgateway + +import ( + "net/http" + "net/http/httputil" + "net/url" + "slices" + "strconv" + "strings" + "time" + + "github.com/go-logr/logr" +) + +var lokiWritePaths = []string{ + "/loki/api/v1/push", + "/api/prom/push", + "/otlp/v1/logs", +} + +// lokiRouter directs requests to the appropriate Loki upstream (distributor or query-frontend). +type lokiRouter struct { + writeProxy *httputil.ReverseProxy + readProxy *httputil.ReverseProxy + logger logr.Logger + defaultTenant string +} + +type responseWriter struct { + http.ResponseWriter + statusCode int +} + +func (rw *responseWriter) WriteHeader(code int) { + rw.statusCode = code + rw.ResponseWriter.WriteHeader(code) +} + +func (rw *responseWriter) Flush() { + if f, ok := rw.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +func (rw *responseWriter) Unwrap() http.ResponseWriter { + return rw.ResponseWriter +} + +// NewLokiRouter creates a new router that directs requests to the appropriate upstream. +func NewLokiRouter(cfg *Config, logger logr.Logger) (*lokiRouter, error) { + transport, err := newTransport(cfg) + if err != nil { + return nil, err + } + + writeProxy, err := newReverseProxy(cfg.Loki.DistributorEndpoint, transport, logger) + if err != nil { + return nil, err + } + + readProxy, err := newReverseProxy(cfg.Loki.QueryFrontendEndpoint, transport, logger) + if err != nil { + return nil, err + } + + return &lokiRouter{ + writeProxy: writeProxy, + readProxy: readProxy, + logger: logger, + defaultTenant: cfg.DefaultTenant, + }, nil +} + +func newTransport(cfg *Config) (*http.Transport, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = nil + transport.ResponseHeaderTimeout = cfg.Loki.Timeout + + tlsConfig, err := BuildUpstreamTLSConfig(cfg.TLSOptions(), cfg.Loki.CAFile, cfg.Loki.CertFile, cfg.Loki.KeyFile) + if err != nil { + return nil, err + } + transport.TLSClientConfig = tlsConfig + + return transport, nil +} + +func newReverseProxy(upstreamEndpoint string, transport *http.Transport, logger logr.Logger) (*httputil.ReverseProxy, error) { + target, err := url.Parse(upstreamEndpoint) + if err != nil { + return nil, err + } + + proxy := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(target) + pr.Out.Host = target.Host + }, + Transport: transport, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + logger.Error(err, "proxy error", "method", r.Method, "path", r.URL.Path) + w.WriteHeader(http.StatusBadGateway) + }, + } + + return proxy, nil +} + +func isWritePath(path string) bool { + return slices.ContainsFunc(lokiWritePaths, func(writePath string) bool { + return strings.HasPrefix(path, writePath) + }) +} + +func (r *lokiRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Header.Get("X-Scope-OrgID") == "" { + if r.defaultTenant == "" { + r.logger.Error(nil, "missing required header", "header", "X-Scope-OrgID", "path", req.URL.Path, "method", req.Method) + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("Bad request: X-Scope-OrgID header is required")) + return + } + req.Header.Set("X-Scope-OrgID", r.defaultTenant) + } + + if isWritePath(req.URL.Path) { + r.writeProxy.ServeHTTP(w, req) + return + } + r.readProxy.ServeHTTP(w, req) +} + +// instrumentedHandler wraps an http.Handler with metrics instrumentation. +func instrumentedHandler(handler http.Handler, metrics *metrics) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := "read" + if isWritePath(r.URL.Path) { + route = "write" + } + + inFlight := metrics.RequestsInFlight.WithLabelValues(route) + inFlight.Inc() + defer inFlight.Dec() + + start := time.Now() + wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} + + handler.ServeHTTP(wrapped, r) + + duration := time.Since(start).Seconds() + metrics.RequestDuration.WithLabelValues(r.Method, route).Observe(duration) + metrics.RequestsTotal.WithLabelValues(r.Method, route, strconv.Itoa(wrapped.statusCode)).Inc() + }) +} diff --git a/operator/internal/passthroughgateway/proxy_test.go b/operator/internal/passthroughgateway/proxy_test.go new file mode 100644 index 0000000000..c29c570802 --- /dev/null +++ b/operator/internal/passthroughgateway/proxy_test.go @@ -0,0 +1,117 @@ +package passthroughgateway + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" +) + +func BenchmarkIsWritePath(b *testing.B) { + benchmarks := []struct { + name string + path string + }{ + {"write_loki_push", "/loki/api/v1/push"}, + {"write_prom_push", "/api/prom/push"}, + {"write_otlp", "/otlp/v1/logs"}, + {"write_with_suffix", "/loki/api/v1/push?foo=bar"}, + {"read_query", "/loki/api/v1/query"}, + {"read_labels", "/loki/api/v1/labels"}, + {"read_series", "/loki/api/v1/series"}, + {"read_tail", "/loki/api/v1/tail"}, + {"read_long_path", "/loki/api/v1/query_range?query=rate({app='test'}[5m])&start=1234567890&end=1234567899"}, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = isWritePath(bm.path) + } + }) + } +} + +func TestInstrumentedHandler(t *testing.T) { + tt := []struct { + name string + method string + path string + route string + handlerStatus int + }{ + { + name: "read request with GET", + method: http.MethodGet, + path: "/loki/api/v1/query", + route: "read", + handlerStatus: http.StatusOK, + }, + { + name: "write request with POST", + method: http.MethodPost, + path: "/loki/api/v1/push", + route: "write", + handlerStatus: http.StatusNoContent, + }, + { + name: "write request otlp", + method: http.MethodPost, + path: "/otlp/v1/logs", + route: "write", + handlerStatus: http.StatusOK, + }, + { + name: "read request with error status", + method: http.MethodGet, + path: "/loki/api/v1/labels", + route: "read", + handlerStatus: http.StatusInternalServerError, + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + registry := prometheus.NewPedanticRegistry() + metrics := newMetrics(registry) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.handlerStatus) + }) + + instrumented := instrumentedHandler(handler, metrics) + + req := httptest.NewRequest(tc.method, tc.path, nil) + rec := httptest.NewRecorder() + + instrumented.ServeHTTP(rec, req) + + require.Equal(t, tc.handlerStatus, rec.Code) + + expectedTotal := fmt.Sprintf(` +# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway. +# TYPE lokistack_gateway_requests_total counter +lokistack_gateway_requests_total{method="%s",route="%s",status_code="%d"} 1 +`, tc.method, tc.route, tc.handlerStatus) + err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expectedTotal)) + require.NoError(t, err) + + durationCount := testutil.CollectAndCount(metrics.RequestDuration) + require.Equal(t, 1, durationCount) + + expectedInFlight := fmt.Sprintf(` +# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway. +# TYPE lokistack_gateway_requests_in_flight gauge +lokistack_gateway_requests_in_flight{route="%s"} 0 +`, tc.route) + err = testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expectedInFlight)) + require.NoError(t, err) + }) + } +} diff --git a/operator/internal/passthroughgateway/server.go b/operator/internal/passthroughgateway/server.go new file mode 100644 index 0000000000..3b82f5bba7 --- /dev/null +++ b/operator/internal/passthroughgateway/server.go @@ -0,0 +1,177 @@ +package passthroughgateway + +import ( + "context" + "errors" + "net/http" + "sync" + "sync/atomic" + "time" + + "github.com/go-logr/logr" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +type Server struct { + config *Config + logger logr.Logger + metrics *metrics + proxyServer *http.Server + adminServer *http.Server + shuttingDown *atomic.Bool +} + +func NewServer(cfg *Config, logger logr.Logger, reg prometheus.Registerer) (*Server, error) { + metrics := newMetrics(reg) + + router, err := NewLokiRouter(cfg, logger) + if err != nil { + return nil, err + } + + proxyHandler := instrumentedHandler(router, metrics) + + var proxyWriteTimeout time.Duration + if cfg.Loki.Timeout > 0 { + proxyWriteTimeout = cfg.Loki.Timeout + 10*time.Second + } + + proxyServer := &http.Server{ + Addr: cfg.ListenAddr, + Handler: proxyHandler, + ReadTimeout: 30 * time.Second, + WriteTimeout: proxyWriteTimeout, + IdleTimeout: 120 * time.Second, + } + + adminServer := &http.Server{ + Addr: cfg.AdminAddr, + ReadTimeout: 10 * time.Second, + } + + if cfg.TLSEnabled() { + tlsConfig, err := BuildServerTLSConfigWithClientAuth(cfg.TLSOptions(), cfg.TLSClientCAFile) + if err != nil { + return nil, err + } + proxyServer.TLSConfig = tlsConfig + + adminTLSConfig, err := BuildServerTLSConfig(cfg.TLSOptions(), cfg.TLSClientCAFile) + if err != nil { + return nil, err + } + adminServer.TLSConfig = adminTLSConfig + } + + adminMux := http.NewServeMux() + adminMux.Handle("/metrics", promhttp.HandlerFor( + reg.(prometheus.Gatherer), + promhttp.HandlerOpts{Registry: reg}, + )) + adminMux.HandleFunc("/live", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + server := &Server{ + config: cfg, + logger: logger, + metrics: metrics, + shuttingDown: &atomic.Bool{}, + } + + adminMux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { + if server.shuttingDown.Load() { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte("shutting down")) + return + } + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) + + adminServer.Handler = adminMux + adminServer.WriteTimeout = 10 * time.Second + + server.proxyServer = proxyServer + server.adminServer = adminServer + + return server, nil +} + +func (s *Server) Run(ctx context.Context) error { + var wg sync.WaitGroup + errChan := make(chan error, 2) + + wg.Go(func() { + var err error + if s.config.TLSEnabled() { + s.logger.Info("starting HTTPS admin server", "addr", s.config.AdminAddr) + err = s.adminServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile) + } else { + s.logger.Info("starting HTTP admin server", "addr", s.config.AdminAddr) + err = s.adminServer.ListenAndServe() + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err + } + }) + + wg.Go(func() { + var err error + if s.config.TLSEnabled() { + s.logger.Info("starting mTLS proxy server", "addr", s.config.ListenAddr) + err = s.proxyServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile) + } else { + s.logger.Info("starting HTTP proxy server", "addr", s.config.ListenAddr) + err = s.proxyServer.ListenAndServe() + } + if err != nil && !errors.Is(err, http.ErrServerClosed) { + errChan <- err + } + }) + + var runErr error + select { + case <-ctx.Done(): + s.logger.Info("shutting down servers") + runErr = s.Shutdown() + case err := <-errChan: + runErr = err + _ = s.Shutdown() + } + + wg.Wait() + return runErr +} + +func (s *Server) Shutdown() error { + s.shuttingDown.Store(true) + + var ( + wg sync.WaitGroup + proxyErr error + adminErr error + ) + + wg.Go(func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.proxyServer.Shutdown(ctx); err != nil { + s.logger.Error(err, "error shutting down proxy server") + proxyErr = err + } + }) + + wg.Go(func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := s.adminServer.Shutdown(ctx); err != nil { + s.logger.Error(err, "error shutting down admin server") + adminErr = err + } + }) + + wg.Wait() + return errors.Join(proxyErr, adminErr) +} diff --git a/operator/internal/passthroughgateway/tls.go b/operator/internal/passthroughgateway/tls.go new file mode 100644 index 0000000000..87c5432415 --- /dev/null +++ b/operator/internal/passthroughgateway/tls.go @@ -0,0 +1,231 @@ +package passthroughgateway + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "os" + "strings" + + "github.com/ViaQ/logerr/v2/kverrors" +) + +var errUnknownClientAuthType = errors.New("unknown client auth type") + +type TLSConfig struct { + MinVersion string + MaxVersion string + CipherSuites []string + CurvePrefs []string + ClientAuth string +} + +var tlsClientAuthTypes = map[string]tls.ClientAuthType{ + "NoClientCert": tls.NoClientCert, + "RequestClientCert": tls.RequestClientCert, + "RequireAnyClientCert": tls.RequireAnyClientCert, + "VerifyClientCertIfGiven": tls.VerifyClientCertIfGiven, + "RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert, +} + +var tlsVersions = map[string]uint16{ + "VersionTLS10": tls.VersionTLS10, + "VersionTLS11": tls.VersionTLS11, + "VersionTLS12": tls.VersionTLS12, + "VersionTLS13": tls.VersionTLS13, +} + +var tlsCipherSuites = map[string]uint16{ + // TLS 1.2 cipher suites + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, + "TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + "TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + "TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, + // TLS 1.3 cipher suites (always enabled when TLS 1.3 is used) + "TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, + "TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, + "TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, +} + +var tlsCurveIDs = map[string]tls.CurveID{ + "X25519": tls.X25519, + "CurveP256": tls.CurveP256, + "CurveP384": tls.CurveP384, + "CurveP521": tls.CurveP521, +} + +// BuildServerTLSConfig creates a TLS config for the server with optional mTLS. +func BuildServerTLSConfig(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) { + tlsConfig, err := buildTLSConfig(cfg) + if err != nil { + return nil, err + } + + if clientCAFile != "" { + caCert, err := os.ReadFile(clientCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read client CA file: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, kverrors.New("failed to append client CA certs") + } + + tlsConfig.ClientCAs = caCertPool + } + + return tlsConfig, nil +} + +// BuildServerTLSConfigWithClientAuth creates a TLS config for the server with optional client authentication. +func BuildServerTLSConfigWithClientAuth(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) { + tlsConfig, err := BuildServerTLSConfig(cfg, clientCAFile) + if err != nil { + return nil, err + } + + if clientCAFile != "" { + clientAuth, err := parseClientAuthType(cfg.ClientAuth) + if err != nil { + return nil, err + } + tlsConfig.ClientAuth = clientAuth + } + + return tlsConfig, nil +} + +// BuildUpstreamTLSConfig creates a TLS config for upstream connections with optional client certificate. +func BuildUpstreamTLSConfig(cfg *TLSConfig, upstreamCAFile, upstreamCertFile, upstreamKeyFile string) (*tls.Config, error) { + tlsConfig, err := buildTLSConfig(cfg) + if err != nil { + return nil, err + } + + // Load upstream CA for server verification + if upstreamCAFile != "" { + caCert, err := os.ReadFile(upstreamCAFile) + if err != nil { + return nil, fmt.Errorf("failed to read upstream CA file: %w", err) + } + + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, kverrors.New("failed to append upstream CA certs") + } + tlsConfig.RootCAs = caCertPool + } + + // Load client certificate for upstream mTLS + if upstreamCertFile != "" && upstreamKeyFile != "" { + cert, err := tls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load upstream client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + return tlsConfig, nil +} + +func parseClientAuthType(authType string) (tls.ClientAuthType, error) { + if authType == "" { + return tls.RequireAndVerifyClientCert, nil + } + auth, ok := tlsClientAuthTypes[authType] + if !ok { + return 0, fmt.Errorf("%w: %s", errUnknownClientAuthType, authType) + } + return auth, nil +} + +func buildTLSConfig(cfg *TLSConfig) (*tls.Config, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + if cfg.MinVersion != "" { + minVersion, err := parseTLSVersion(cfg.MinVersion) + if err != nil { + return nil, fmt.Errorf("invalid minimum TLS version: %w", err) + } + tlsConfig.MinVersion = minVersion + } + + if cfg.MaxVersion != "" { + maxVersion, err := parseTLSVersion(cfg.MaxVersion) + if err != nil { + return nil, fmt.Errorf("invalid maximum TLS version: %w", err) + } + tlsConfig.MaxVersion = maxVersion + } + + if len(cfg.CipherSuites) > 0 { + cipherSuites, err := parseCipherSuites(cfg.CipherSuites) + if err != nil { + return nil, fmt.Errorf("invalid cipher suites: %w", err) + } + tlsConfig.CipherSuites = cipherSuites + } + + if len(cfg.CurvePrefs) > 0 { + curvePrefs, err := parseCurvePreferences(cfg.CurvePrefs) + if err != nil { + return nil, fmt.Errorf("invalid curve preferences: %w", err) + } + tlsConfig.CurvePreferences = curvePrefs + } + + return tlsConfig, nil +} + +func parseTLSVersion(version string) (uint16, error) { + if version == "" { + return 0, nil + } + v, ok := tlsVersions[version] + if !ok { + return 0, kverrors.New("unknown TLS version", "version", version) + } + return v, nil +} + +func parseCipherSuites(ciphers []string) ([]uint16, error) { + if len(ciphers) == 0 { + return nil, nil + } + result := make([]uint16, 0, len(ciphers)) + for _, name := range ciphers { + id, ok := tlsCipherSuites[strings.TrimSpace(name)] + if !ok { + return nil, kverrors.New("unknown cipher suite", "cipher", name) + } + result = append(result, id) + } + return result, nil +} + +func parseCurvePreferences(curves []string) ([]tls.CurveID, error) { + if len(curves) == 0 { + return nil, nil + } + result := make([]tls.CurveID, 0, len(curves)) + for _, name := range curves { + id, ok := tlsCurveIDs[strings.TrimSpace(name)] + if !ok { + return nil, kverrors.New("unknown curve", "curve", name) + } + result = append(result, id) + } + return result, nil +} diff --git a/operator/internal/passthroughgateway/tls_test.go b/operator/internal/passthroughgateway/tls_test.go new file mode 100644 index 0000000000..0ce456cd8e --- /dev/null +++ b/operator/internal/passthroughgateway/tls_test.go @@ -0,0 +1,307 @@ +package passthroughgateway + +import ( + "crypto/tls" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildTLSConfig(t *testing.T) { + tt := []struct { + desc string + cfg *TLSConfig + wantErr bool + errContains string + validate func(t *testing.T, cfg *tls.Config) + }{ + { + desc: "default config with TLS 1.2 min version", + cfg: &TLSConfig{}, + validate: func(t *testing.T, cfg *tls.Config) { + require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion) + }, + }, + { + desc: "custom min and maximum version", + cfg: &TLSConfig{ + MinVersion: "VersionTLS12", + MaxVersion: "VersionTLS13", + }, + validate: func(t *testing.T, cfg *tls.Config) { + require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion) + require.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion) + }, + }, + { + desc: "invalid min version", + cfg: &TLSConfig{ + MinVersion: "TLS99", + }, + wantErr: true, + errContains: "invalid minimum TLS version", + }, + { + desc: "invalid max version", + cfg: &TLSConfig{ + MaxVersion: "invalid", + }, + wantErr: true, + errContains: "invalid maximum TLS version", + }, + { + desc: "valid cipher suites", + cfg: &TLSConfig{ + CipherSuites: []string{ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + }, + }, + validate: func(t *testing.T, cfg *tls.Config) { + require.Len(t, cfg.CipherSuites, 2) + require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)) + require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)) + }, + }, + { + desc: "invalid cipher suite", + cfg: &TLSConfig{ + CipherSuites: []string{"INVALID_CIPHER"}, + }, + wantErr: true, + errContains: "invalid cipher suites", + }, + { + desc: "valid curve preferences", + cfg: &TLSConfig{ + CurvePrefs: []string{"X25519", "CurveP256"}, + }, + validate: func(t *testing.T, cfg *tls.Config) { + require.Len(t, cfg.CurvePreferences, 2) + require.Contains(t, cfg.CurvePreferences, tls.X25519) + require.Contains(t, cfg.CurvePreferences, tls.CurveP256) + }, + }, + { + desc: "invalid curve preference", + cfg: &TLSConfig{ + CurvePrefs: []string{"INVALID_CURVE"}, + }, + wantErr: true, + errContains: "invalid curve preferences", + }, + } + + for _, tc := range tt { + t.Run(tc.desc, func(t *testing.T) { + cfg, err := buildTLSConfig(tc.cfg) + if tc.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errContains) + return + } + require.NoError(t, err) + if tc.validate != nil { + tc.validate(t, cfg) + } + }) + } +} + +func TestParseTLSVersion(t *testing.T) { + tt := []struct { + desc string + version string + want uint16 + wantErr bool + }{ + { + desc: "empty version returns zero", + version: "", + want: 0, + }, + { + desc: "TLS 1.0", + version: "VersionTLS10", + want: tls.VersionTLS10, + }, + { + desc: "TLS 1.2", + version: "VersionTLS12", + want: tls.VersionTLS12, + }, + { + desc: "TLS 1.3", + version: "VersionTLS13", + want: tls.VersionTLS13, + }, + { + desc: "unknown version", + version: "TLS99", + wantErr: true, + }, + } + + for _, tc := range tt { + t.Run(tc.desc, func(t *testing.T) { + got, err := parseTLSVersion(tc.version) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestParseClientAuthType(t *testing.T) { + tt := []struct { + desc string + authType string + want tls.ClientAuthType + wantErr bool + }{ + { + desc: "empty defaults to RequireAndVerifyClientCert", + authType: "", + want: tls.RequireAndVerifyClientCert, + }, + { + desc: "NoClientCert", + authType: "NoClientCert", + want: tls.NoClientCert, + }, + { + desc: "RequestClientCert", + authType: "RequestClientCert", + want: tls.RequestClientCert, + }, + { + desc: "RequireAnyClientCert", + authType: "RequireAnyClientCert", + want: tls.RequireAnyClientCert, + }, + { + desc: "VerifyClientCertIfGiven", + authType: "VerifyClientCertIfGiven", + want: tls.VerifyClientCertIfGiven, + }, + { + desc: "RequireAndVerifyClientCert", + authType: "RequireAndVerifyClientCert", + want: tls.RequireAndVerifyClientCert, + }, + { + desc: "unknown auth type", + authType: "InvalidAuthType", + wantErr: true, + }, + } + + for _, tc := range tt { + t.Run(tc.desc, func(t *testing.T) { + got, err := parseClientAuthType(tc.authType) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestParseCipherSuites(t *testing.T) { + tt := []struct { + desc string + ciphers []string + want []uint16 + wantErr bool + }{ + { + desc: "empty returns nil", + ciphers: []string{}, + want: nil, + }, + { + desc: "single cipher", + ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, + want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + }, + { + desc: "multiple ciphers", + ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256"}, + want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_AES_128_GCM_SHA256}, + }, + { + desc: "cipher with whitespace", + ciphers: []string{" TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 "}, + want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + }, + { + desc: "unknown cipher", + ciphers: []string{"UNKNOWN_CIPHER"}, + wantErr: true, + }, + } + + for _, tc := range tt { + t.Run(tc.desc, func(t *testing.T) { + got, err := parseCipherSuites(tc.ciphers) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestParseCurvePreferences(t *testing.T) { + tt := []struct { + desc string + curves []string + want []tls.CurveID + wantErr bool + }{ + { + desc: "empty returns nil", + curves: []string{}, + want: nil, + }, + { + desc: "X25519", + curves: []string{"X25519"}, + want: []tls.CurveID{tls.X25519}, + }, + { + desc: "multiple curves", + curves: []string{"X25519", "CurveP256", "CurveP384"}, + want: []tls.CurveID{tls.X25519, tls.CurveP256, tls.CurveP384}, + }, + { + desc: "curve with whitespace", + curves: []string{" CurveP256 "}, + want: []tls.CurveID{tls.CurveP256}, + }, + { + desc: "unknown curve", + curves: []string{"UNKNOWN_CURVE"}, + wantErr: true, + }, + } + + for _, tc := range tt { + t.Run(tc.desc, func(t *testing.T) { + got, err := parseCurvePreferences(tc.curves) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/operator/passthrough-gateway.Dockerfile b/operator/passthrough-gateway.Dockerfile new file mode 100644 index 0000000000..57898dab7b --- /dev/null +++ b/operator/passthrough-gateway.Dockerfile @@ -0,0 +1,26 @@ +# Build the passthrough-gateway binary +FROM golang:1.26.3 as builder + +WORKDIR /workspace +# Copy the Go Modules manifests +COPY api/ api/ +COPY go.mod go.sum ./ +# cache deps before building and copying source so that we don't need to re-download as much +# and so that source changes don't invalidate our downloaded layer +RUN go mod download + +# Copy the go source +COPY cmd/passthrough-gateway/main.go cmd/passthrough-gateway/main.go +COPY internal/ internal/ + +# Build +RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -mod=readonly -o passthrough-gateway cmd/passthrough-gateway/main.go + +# Use distroless as minimal base image to package the passthrough-gateway binary +# Refer to https://github.com/GoogleContainerTools/distroless for more details +FROM gcr.io/distroless/static:nonroot +WORKDIR / +COPY --from=builder /workspace/passthrough-gateway . +USER 65532:65532 + +ENTRYPOINT ["/passthrough-gateway"]