feat(operator): Add passthrough-gateway component (#20382)

Signed-off-by: Joao Marcal <jmarcal@redhat.com>
Co-authored-by: Robert Jacob <rojacob@redhat.com>
pull/21989/merge
Joao Marcal 1 week ago committed by GitHub
parent 98d43a23c1
commit bf06e45bcf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 9
      .github/workflows/operator-images.yaml
  2. 10
      operator/Makefile
  3. 60
      operator/cmd/passthrough-gateway/main.go
  4. 126
      operator/internal/passthroughgateway/config.go
  5. 43
      operator/internal/passthroughgateway/metrics.go
  6. 57
      operator/internal/passthroughgateway/metrics_test.go
  7. 154
      operator/internal/passthroughgateway/proxy.go
  8. 117
      operator/internal/passthroughgateway/proxy_test.go
  9. 177
      operator/internal/passthroughgateway/server.go
  10. 231
      operator/internal/passthroughgateway/tls.go
  11. 307
      operator/internal/passthroughgateway/tls_test.go
  12. 26
      operator/passthrough-gateway.Dockerfile

@ -44,3 +44,12 @@ jobs:
organization: "openshift-logging"
image_name: "storage-size-calculator"
tag: "latest"
publish-openshift-passthrough-gateway:
uses: ./.github/workflows/operator-reusable-image-build.yml
with:
dockerfile: "operator/passthrough-gateway.Dockerfile"
registry: "quay.io"
organization: "openshift-logging"
image_name: "passthrough-gateway"
tag: "latest"

@ -85,6 +85,7 @@ ifeq ($(USE_IMAGE_DIGESTS), true)
endif
CALCULATOR_IMG ?= $(REGISTRY_BASE)/storage-size-calculator:latest
PASSTHROUGH_GATEWAY_IMG ?= $(REGISTRY_BASE)/passthrough-gateway:latest
GO_FILES := $(shell find . -type f -name '*.go')
@ -313,6 +314,15 @@ oci-build-calculator: ## Build the calculator image
oci-push-calculator: ## Push the calculator image
$(OCI_RUNTIME) push $(CALCULATOR_IMG)
##@ Passthrough Gateway
.PHONY: oci-build-passthrough-gateway
oci-build-passthrough-gateway: ## Build the passthrough gateway image
$(OCI_RUNTIME) build -f passthrough-gateway.Dockerfile -t $(PASSTHROUGH_GATEWAY_IMG) .
.PHONY: oci-push-passthrough-gateway
oci-push-passthrough-gateway: ## Push the passthrough gateway image
$(OCI_RUNTIME) push $(PASSTHROUGH_GATEWAY_IMG)
##@ Website
TYPES_TARGET := $(shell find api/loki -type f -iname "*_types.go")

@ -0,0 +1,60 @@
package main
import (
"context"
"flag"
"os"
"os/signal"
"syscall"
"github.com/ViaQ/logerr/v2/log"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/collectors"
"github.com/grafana/loki/operator/internal/passthroughgateway"
)
func main() {
logger := log.NewLogger("lokistack-gateway")
cfg := &passthroughgateway.Config{}
f := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
cfg.RegisterFlags(f)
if err := f.Parse(os.Args[1:]); err != nil {
logger.Error(err, "failed to parse flags")
os.Exit(1)
}
if err := cfg.Validate(); err != nil {
logger.Error(err, "invalid configuration")
os.Exit(1)
}
logger.Info("starting gateway",
"listen-addr", cfg.ListenAddr,
"admin-addr", cfg.AdminAddr,
"loki-distributor-endpoint", cfg.Loki.DistributorEndpoint,
"loki-query-frontend-endpoint", cfg.Loki.QueryFrontendEndpoint,
)
reg := prometheus.NewRegistry()
reg.MustRegister(collectors.NewGoCollector())
reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}))
server, err := passthroughgateway.NewServer(cfg, logger, reg)
if err != nil {
logger.Error(err, "failed to create server")
os.Exit(1)
}
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
if err := server.Run(ctx); err != nil {
logger.Error(err, "server error")
os.Exit(1)
}
logger.Info("lokistack gateway stopped")
}

@ -0,0 +1,126 @@
package passthroughgateway
import (
"flag"
"net/url"
"strings"
"time"
"github.com/ViaQ/logerr/v2/kverrors"
)
type LokiConfig struct {
DistributorEndpoint string
QueryFrontendEndpoint string
Timeout time.Duration
CAFile string
CertFile string
KeyFile string
}
type Config struct {
ListenAddr string
AdminAddr string
DefaultTenant string
Loki LokiConfig
// Server TLS configuration
TLSCertFile string
TLSKeyFile string
// Client TLS configuration
TLSClientCAFile string
TLSClientAuth string
// Generic TLS configuration
TLSMinVersion string
TLSMaxVersion string
TLSCipherSuites StringSlice
TLSCurvePrefs StringSlice
}
// StringSlice is a custom flag type for comma-separated strings.
type StringSlice []string
func (s *StringSlice) String() string {
return strings.Join(*s, ",")
}
func (s *StringSlice) Set(value string) error {
if value == "" {
return nil
}
*s = strings.Split(value, ",")
return nil
}
// TLSOptions returns the TLS configuration options.
func (c *Config) TLSOptions() *TLSConfig {
return &TLSConfig{
MinVersion: c.TLSMinVersion,
MaxVersion: c.TLSMaxVersion,
CipherSuites: c.TLSCipherSuites,
CurvePrefs: c.TLSCurvePrefs,
ClientAuth: c.TLSClientAuth,
}
}
// RegisterFlags registers the configuration flags.
func (c *Config) RegisterFlags(f *flag.FlagSet) {
f.StringVar(&c.ListenAddr, "listen-addr", ":8080", "Address for the server to listen on.")
f.StringVar(&c.AdminAddr, "admin-addr", ":9090", "Address for admin endpoints (metrics, health, readiness).")
f.StringVar(&c.DefaultTenant, "default-tenant", "", "Default tenant ID to use when X-Scope-OrgID header is not set. If empty, requests without X-Scope-OrgID are rejected.")
// Loki upstream configuration flags
f.StringVar(&c.Loki.DistributorEndpoint, "loki-distributor-endpoint", "", "Upstream URL of the Loki distributor (write path).")
f.StringVar(&c.Loki.QueryFrontendEndpoint, "loki-query-frontend-endpoint", "", "Upstream URL of the Loki query frontend (read path).")
f.DurationVar(&c.Loki.Timeout, "loki-timeout", 60*time.Second, "Timeout for upstream Loki requests. Set to 0 for no timeout.")
f.StringVar(&c.Loki.CAFile, "loki-ca-file", "", "Path to the CA certificate for verifying the Loki server.")
f.StringVar(&c.Loki.CertFile, "loki-cert-file", "", "Path to the client certificate for Loki mTLS.")
f.StringVar(&c.Loki.KeyFile, "loki-key-file", "", "Path to the client private key for Loki mTLS.")
// Server TLS configuration flags
f.StringVar(&c.TLSCertFile, "tls-cert-file", "", "Path to the server TLS certificate file.")
f.StringVar(&c.TLSKeyFile, "tls-key-file", "", "Path to the server TLS private key file.")
// Client TLS configuration flags
f.StringVar(&c.TLSClientCAFile, "tls-client-ca-file", "", "Path to the CA certificate for verifying client certificates.")
f.StringVar(&c.TLSClientAuth, "tls-client-auth", "NoClientCert", "Client certificate auth mode (NoClientCert, RequestClientCert, RequireAnyClientCert, VerifyClientCertIfGiven, RequireAndVerifyClientCert).")
// Generic TLS configuration flags
f.StringVar(&c.TLSMinVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13).")
f.StringVar(&c.TLSMaxVersion, "tls-max-version", "", "Maximum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13). Empty means no maximum.")
f.Var(&c.TLSCipherSuites, "tls-cipher-suites", "Comma-separated list of TLS cipher suites.")
f.Var(&c.TLSCurvePrefs, "tls-curve-preferences", "Comma-separated list of curve preferences (X25519, CurveP256, CurveP384, CurveP521).")
}
// Validate checks that all required configuration is provided and valid.
func (c *Config) Validate() error {
if c.Loki.DistributorEndpoint == "" {
return kverrors.New("-loki-distributor-endpoint is required")
}
u, err := url.Parse(c.Loki.DistributorEndpoint)
if err != nil {
return kverrors.New("-loki-distributor-endpoint is not a valid URL", "endpoint", c.Loki.DistributorEndpoint, "err", err)
}
if u.Host == "" {
return kverrors.New("-loki-distributor-endpoint is missing a host", "endpoint", c.Loki.DistributorEndpoint)
}
if c.Loki.QueryFrontendEndpoint == "" {
return kverrors.New("-loki-query-frontend-endpoint is required")
}
u, err = url.Parse(c.Loki.QueryFrontendEndpoint)
if err != nil {
return kverrors.New("-loki-query-frontend-endpoint is not a valid URL", "endpoint", c.Loki.QueryFrontendEndpoint, "err", err)
}
if u.Host == "" {
return kverrors.New("-loki-query-frontend-endpoint is missing a host", "endpoint", c.Loki.QueryFrontendEndpoint)
}
if c.DefaultTenant != "" && strings.ContainsAny(c.DefaultTenant, "\n") {
return kverrors.New("-default-tenant must not contain newline characters")
}
return nil
}
// TLSEnabled returns true if TLS certificates are configured.
func (c *Config) TLSEnabled() bool {
return c.TLSCertFile != "" && c.TLSKeyFile != ""
}

@ -0,0 +1,43 @@
package passthroughgateway
import (
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)
// metrics holds Prometheus metrics for the LokiStack gateway.
type metrics struct {
RequestsTotal *prometheus.CounterVec
RequestDuration *prometheus.HistogramVec
RequestsInFlight *prometheus.GaugeVec
}
// newMetrics creates and registers Prometheus metrics for the LokiStack gateway.
func newMetrics(reg prometheus.Registerer) *metrics {
factory := promauto.With(reg)
return &metrics{
RequestsTotal: factory.NewCounterVec(
prometheus.CounterOpts{
Name: "lokistack_gateway_requests_total",
Help: "Total number of requests processed by the LokiStack gateway.",
},
[]string{"method", "route", "status_code"},
),
RequestDuration: factory.NewHistogramVec(
prometheus.HistogramOpts{
Name: "lokistack_gateway_request_duration_seconds",
Help: "Duration of requests processed by the LokiStack gateway.",
Buckets: []float64{.1, 1, 5, 9, 15, 30, 60, 120, 300},
},
[]string{"method", "route"},
),
RequestsInFlight: factory.NewGaugeVec(
prometheus.GaugeOpts{
Name: "lokistack_gateway_requests_in_flight",
Help: "Current number of requests being processed by the LokiStack gateway.",
},
[]string{"route"},
),
}
}

@ -0,0 +1,57 @@
package passthroughgateway
import (
"strings"
"testing"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
)
func TestMetricsRequestsTotal(t *testing.T) {
registry := prometheus.NewPedanticRegistry()
metrics := newMetrics(registry)
metrics.RequestsTotal.WithLabelValues("GET", "read", "200").Inc()
metrics.RequestsTotal.WithLabelValues("POST", "write", "201").Inc()
expected := `
# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway.
# TYPE lokistack_gateway_requests_total counter
lokistack_gateway_requests_total{method="GET",route="read",status_code="200"} 1
lokistack_gateway_requests_total{method="POST",route="write",status_code="201"} 1
`
err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expected))
require.NoError(t, err)
}
func TestMetricsRequestDuration(t *testing.T) {
registry := prometheus.NewPedanticRegistry()
metrics := newMetrics(registry)
metrics.RequestDuration.WithLabelValues("GET", "/api/logs").Observe(0.5)
count := testutil.CollectAndCount(metrics.RequestDuration)
require.Equal(t, 1, count)
}
func TestMetricsRequestsInFlight(t *testing.T) {
registry := prometheus.NewPedanticRegistry()
metrics := newMetrics(registry)
metrics.RequestsInFlight.WithLabelValues("read").Inc()
metrics.RequestsInFlight.WithLabelValues("write").Inc()
metrics.RequestsInFlight.WithLabelValues("read").Dec()
expected := `
# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway.
# TYPE lokistack_gateway_requests_in_flight gauge
lokistack_gateway_requests_in_flight{route="read"} 0
lokistack_gateway_requests_in_flight{route="write"} 1
`
err := testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expected))
require.NoError(t, err)
}

@ -0,0 +1,154 @@
package passthroughgateway
import (
"net/http"
"net/http/httputil"
"net/url"
"slices"
"strconv"
"strings"
"time"
"github.com/go-logr/logr"
)
var lokiWritePaths = []string{
"/loki/api/v1/push",
"/api/prom/push",
"/otlp/v1/logs",
}
// lokiRouter directs requests to the appropriate Loki upstream (distributor or query-frontend).
type lokiRouter struct {
writeProxy *httputil.ReverseProxy
readProxy *httputil.ReverseProxy
logger logr.Logger
defaultTenant string
}
type responseWriter struct {
http.ResponseWriter
statusCode int
}
func (rw *responseWriter) WriteHeader(code int) {
rw.statusCode = code
rw.ResponseWriter.WriteHeader(code)
}
func (rw *responseWriter) Flush() {
if f, ok := rw.ResponseWriter.(http.Flusher); ok {
f.Flush()
}
}
func (rw *responseWriter) Unwrap() http.ResponseWriter {
return rw.ResponseWriter
}
// NewLokiRouter creates a new router that directs requests to the appropriate upstream.
func NewLokiRouter(cfg *Config, logger logr.Logger) (*lokiRouter, error) {
transport, err := newTransport(cfg)
if err != nil {
return nil, err
}
writeProxy, err := newReverseProxy(cfg.Loki.DistributorEndpoint, transport, logger)
if err != nil {
return nil, err
}
readProxy, err := newReverseProxy(cfg.Loki.QueryFrontendEndpoint, transport, logger)
if err != nil {
return nil, err
}
return &lokiRouter{
writeProxy: writeProxy,
readProxy: readProxy,
logger: logger,
defaultTenant: cfg.DefaultTenant,
}, nil
}
func newTransport(cfg *Config) (*http.Transport, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
transport.Proxy = nil
transport.ResponseHeaderTimeout = cfg.Loki.Timeout
tlsConfig, err := BuildUpstreamTLSConfig(cfg.TLSOptions(), cfg.Loki.CAFile, cfg.Loki.CertFile, cfg.Loki.KeyFile)
if err != nil {
return nil, err
}
transport.TLSClientConfig = tlsConfig
return transport, nil
}
func newReverseProxy(upstreamEndpoint string, transport *http.Transport, logger logr.Logger) (*httputil.ReverseProxy, error) {
target, err := url.Parse(upstreamEndpoint)
if err != nil {
return nil, err
}
proxy := &httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
pr.SetURL(target)
pr.Out.Host = target.Host
},
Transport: transport,
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
logger.Error(err, "proxy error", "method", r.Method, "path", r.URL.Path)
w.WriteHeader(http.StatusBadGateway)
},
}
return proxy, nil
}
func isWritePath(path string) bool {
return slices.ContainsFunc(lokiWritePaths, func(writePath string) bool {
return strings.HasPrefix(path, writePath)
})
}
func (r *lokiRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Header.Get("X-Scope-OrgID") == "" {
if r.defaultTenant == "" {
r.logger.Error(nil, "missing required header", "header", "X-Scope-OrgID", "path", req.URL.Path, "method", req.Method)
w.WriteHeader(http.StatusBadRequest)
_, _ = w.Write([]byte("Bad request: X-Scope-OrgID header is required"))
return
}
req.Header.Set("X-Scope-OrgID", r.defaultTenant)
}
if isWritePath(req.URL.Path) {
r.writeProxy.ServeHTTP(w, req)
return
}
r.readProxy.ServeHTTP(w, req)
}
// instrumentedHandler wraps an http.Handler with metrics instrumentation.
func instrumentedHandler(handler http.Handler, metrics *metrics) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
route := "read"
if isWritePath(r.URL.Path) {
route = "write"
}
inFlight := metrics.RequestsInFlight.WithLabelValues(route)
inFlight.Inc()
defer inFlight.Dec()
start := time.Now()
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
handler.ServeHTTP(wrapped, r)
duration := time.Since(start).Seconds()
metrics.RequestDuration.WithLabelValues(r.Method, route).Observe(duration)
metrics.RequestsTotal.WithLabelValues(r.Method, route, strconv.Itoa(wrapped.statusCode)).Inc()
})
}

@ -0,0 +1,117 @@
package passthroughgateway
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/testutil"
"github.com/stretchr/testify/require"
)
func BenchmarkIsWritePath(b *testing.B) {
benchmarks := []struct {
name string
path string
}{
{"write_loki_push", "/loki/api/v1/push"},
{"write_prom_push", "/api/prom/push"},
{"write_otlp", "/otlp/v1/logs"},
{"write_with_suffix", "/loki/api/v1/push?foo=bar"},
{"read_query", "/loki/api/v1/query"},
{"read_labels", "/loki/api/v1/labels"},
{"read_series", "/loki/api/v1/series"},
{"read_tail", "/loki/api/v1/tail"},
{"read_long_path", "/loki/api/v1/query_range?query=rate({app='test'}[5m])&start=1234567890&end=1234567899"},
}
for _, bm := range benchmarks {
b.Run(bm.name, func(b *testing.B) {
b.ReportAllocs()
for b.Loop() {
_ = isWritePath(bm.path)
}
})
}
}
func TestInstrumentedHandler(t *testing.T) {
tt := []struct {
name string
method string
path string
route string
handlerStatus int
}{
{
name: "read request with GET",
method: http.MethodGet,
path: "/loki/api/v1/query",
route: "read",
handlerStatus: http.StatusOK,
},
{
name: "write request with POST",
method: http.MethodPost,
path: "/loki/api/v1/push",
route: "write",
handlerStatus: http.StatusNoContent,
},
{
name: "write request otlp",
method: http.MethodPost,
path: "/otlp/v1/logs",
route: "write",
handlerStatus: http.StatusOK,
},
{
name: "read request with error status",
method: http.MethodGet,
path: "/loki/api/v1/labels",
route: "read",
handlerStatus: http.StatusInternalServerError,
},
}
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
registry := prometheus.NewPedanticRegistry()
metrics := newMetrics(registry)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.handlerStatus)
})
instrumented := instrumentedHandler(handler, metrics)
req := httptest.NewRequest(tc.method, tc.path, nil)
rec := httptest.NewRecorder()
instrumented.ServeHTTP(rec, req)
require.Equal(t, tc.handlerStatus, rec.Code)
expectedTotal := fmt.Sprintf(`
# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway.
# TYPE lokistack_gateway_requests_total counter
lokistack_gateway_requests_total{method="%s",route="%s",status_code="%d"} 1
`, tc.method, tc.route, tc.handlerStatus)
err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expectedTotal))
require.NoError(t, err)
durationCount := testutil.CollectAndCount(metrics.RequestDuration)
require.Equal(t, 1, durationCount)
expectedInFlight := fmt.Sprintf(`
# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway.
# TYPE lokistack_gateway_requests_in_flight gauge
lokistack_gateway_requests_in_flight{route="%s"} 0
`, tc.route)
err = testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expectedInFlight))
require.NoError(t, err)
})
}
}

@ -0,0 +1,177 @@
package passthroughgateway
import (
"context"
"errors"
"net/http"
"sync"
"sync/atomic"
"time"
"github.com/go-logr/logr"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
type Server struct {
config *Config
logger logr.Logger
metrics *metrics
proxyServer *http.Server
adminServer *http.Server
shuttingDown *atomic.Bool
}
func NewServer(cfg *Config, logger logr.Logger, reg prometheus.Registerer) (*Server, error) {
metrics := newMetrics(reg)
router, err := NewLokiRouter(cfg, logger)
if err != nil {
return nil, err
}
proxyHandler := instrumentedHandler(router, metrics)
var proxyWriteTimeout time.Duration
if cfg.Loki.Timeout > 0 {
proxyWriteTimeout = cfg.Loki.Timeout + 10*time.Second
}
proxyServer := &http.Server{
Addr: cfg.ListenAddr,
Handler: proxyHandler,
ReadTimeout: 30 * time.Second,
WriteTimeout: proxyWriteTimeout,
IdleTimeout: 120 * time.Second,
}
adminServer := &http.Server{
Addr: cfg.AdminAddr,
ReadTimeout: 10 * time.Second,
}
if cfg.TLSEnabled() {
tlsConfig, err := BuildServerTLSConfigWithClientAuth(cfg.TLSOptions(), cfg.TLSClientCAFile)
if err != nil {
return nil, err
}
proxyServer.TLSConfig = tlsConfig
adminTLSConfig, err := BuildServerTLSConfig(cfg.TLSOptions(), cfg.TLSClientCAFile)
if err != nil {
return nil, err
}
adminServer.TLSConfig = adminTLSConfig
}
adminMux := http.NewServeMux()
adminMux.Handle("/metrics", promhttp.HandlerFor(
reg.(prometheus.Gatherer),
promhttp.HandlerOpts{Registry: reg},
))
adminMux.HandleFunc("/live", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
server := &Server{
config: cfg,
logger: logger,
metrics: metrics,
shuttingDown: &atomic.Bool{},
}
adminMux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) {
if server.shuttingDown.Load() {
w.WriteHeader(http.StatusServiceUnavailable)
_, _ = w.Write([]byte("shutting down"))
return
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
})
adminServer.Handler = adminMux
adminServer.WriteTimeout = 10 * time.Second
server.proxyServer = proxyServer
server.adminServer = adminServer
return server, nil
}
func (s *Server) Run(ctx context.Context) error {
var wg sync.WaitGroup
errChan := make(chan error, 2)
wg.Go(func() {
var err error
if s.config.TLSEnabled() {
s.logger.Info("starting HTTPS admin server", "addr", s.config.AdminAddr)
err = s.adminServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile)
} else {
s.logger.Info("starting HTTP admin server", "addr", s.config.AdminAddr)
err = s.adminServer.ListenAndServe()
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
})
wg.Go(func() {
var err error
if s.config.TLSEnabled() {
s.logger.Info("starting mTLS proxy server", "addr", s.config.ListenAddr)
err = s.proxyServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile)
} else {
s.logger.Info("starting HTTP proxy server", "addr", s.config.ListenAddr)
err = s.proxyServer.ListenAndServe()
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
errChan <- err
}
})
var runErr error
select {
case <-ctx.Done():
s.logger.Info("shutting down servers")
runErr = s.Shutdown()
case err := <-errChan:
runErr = err
_ = s.Shutdown()
}
wg.Wait()
return runErr
}
func (s *Server) Shutdown() error {
s.shuttingDown.Store(true)
var (
wg sync.WaitGroup
proxyErr error
adminErr error
)
wg.Go(func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.proxyServer.Shutdown(ctx); err != nil {
s.logger.Error(err, "error shutting down proxy server")
proxyErr = err
}
})
wg.Go(func() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := s.adminServer.Shutdown(ctx); err != nil {
s.logger.Error(err, "error shutting down admin server")
adminErr = err
}
})
wg.Wait()
return errors.Join(proxyErr, adminErr)
}

@ -0,0 +1,231 @@
package passthroughgateway
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"os"
"strings"
"github.com/ViaQ/logerr/v2/kverrors"
)
var errUnknownClientAuthType = errors.New("unknown client auth type")
type TLSConfig struct {
MinVersion string
MaxVersion string
CipherSuites []string
CurvePrefs []string
ClientAuth string
}
var tlsClientAuthTypes = map[string]tls.ClientAuthType{
"NoClientCert": tls.NoClientCert,
"RequestClientCert": tls.RequestClientCert,
"RequireAnyClientCert": tls.RequireAnyClientCert,
"VerifyClientCertIfGiven": tls.VerifyClientCertIfGiven,
"RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert,
}
var tlsVersions = map[string]uint16{
"VersionTLS10": tls.VersionTLS10,
"VersionTLS11": tls.VersionTLS11,
"VersionTLS12": tls.VersionTLS12,
"VersionTLS13": tls.VersionTLS13,
}
var tlsCipherSuites = map[string]uint16{
// TLS 1.2 cipher suites
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256,
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256,
// TLS 1.3 cipher suites (always enabled when TLS 1.3 is used)
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256,
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384,
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256,
}
var tlsCurveIDs = map[string]tls.CurveID{
"X25519": tls.X25519,
"CurveP256": tls.CurveP256,
"CurveP384": tls.CurveP384,
"CurveP521": tls.CurveP521,
}
// BuildServerTLSConfig creates a TLS config for the server with optional mTLS.
func BuildServerTLSConfig(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) {
tlsConfig, err := buildTLSConfig(cfg)
if err != nil {
return nil, err
}
if clientCAFile != "" {
caCert, err := os.ReadFile(clientCAFile)
if err != nil {
return nil, fmt.Errorf("failed to read client CA file: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, kverrors.New("failed to append client CA certs")
}
tlsConfig.ClientCAs = caCertPool
}
return tlsConfig, nil
}
// BuildServerTLSConfigWithClientAuth creates a TLS config for the server with optional client authentication.
func BuildServerTLSConfigWithClientAuth(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) {
tlsConfig, err := BuildServerTLSConfig(cfg, clientCAFile)
if err != nil {
return nil, err
}
if clientCAFile != "" {
clientAuth, err := parseClientAuthType(cfg.ClientAuth)
if err != nil {
return nil, err
}
tlsConfig.ClientAuth = clientAuth
}
return tlsConfig, nil
}
// BuildUpstreamTLSConfig creates a TLS config for upstream connections with optional client certificate.
func BuildUpstreamTLSConfig(cfg *TLSConfig, upstreamCAFile, upstreamCertFile, upstreamKeyFile string) (*tls.Config, error) {
tlsConfig, err := buildTLSConfig(cfg)
if err != nil {
return nil, err
}
// Load upstream CA for server verification
if upstreamCAFile != "" {
caCert, err := os.ReadFile(upstreamCAFile)
if err != nil {
return nil, fmt.Errorf("failed to read upstream CA file: %w", err)
}
caCertPool := x509.NewCertPool()
if !caCertPool.AppendCertsFromPEM(caCert) {
return nil, kverrors.New("failed to append upstream CA certs")
}
tlsConfig.RootCAs = caCertPool
}
// Load client certificate for upstream mTLS
if upstreamCertFile != "" && upstreamKeyFile != "" {
cert, err := tls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load upstream client certificate: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}
return tlsConfig, nil
}
func parseClientAuthType(authType string) (tls.ClientAuthType, error) {
if authType == "" {
return tls.RequireAndVerifyClientCert, nil
}
auth, ok := tlsClientAuthTypes[authType]
if !ok {
return 0, fmt.Errorf("%w: %s", errUnknownClientAuthType, authType)
}
return auth, nil
}
func buildTLSConfig(cfg *TLSConfig) (*tls.Config, error) {
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
}
if cfg.MinVersion != "" {
minVersion, err := parseTLSVersion(cfg.MinVersion)
if err != nil {
return nil, fmt.Errorf("invalid minimum TLS version: %w", err)
}
tlsConfig.MinVersion = minVersion
}
if cfg.MaxVersion != "" {
maxVersion, err := parseTLSVersion(cfg.MaxVersion)
if err != nil {
return nil, fmt.Errorf("invalid maximum TLS version: %w", err)
}
tlsConfig.MaxVersion = maxVersion
}
if len(cfg.CipherSuites) > 0 {
cipherSuites, err := parseCipherSuites(cfg.CipherSuites)
if err != nil {
return nil, fmt.Errorf("invalid cipher suites: %w", err)
}
tlsConfig.CipherSuites = cipherSuites
}
if len(cfg.CurvePrefs) > 0 {
curvePrefs, err := parseCurvePreferences(cfg.CurvePrefs)
if err != nil {
return nil, fmt.Errorf("invalid curve preferences: %w", err)
}
tlsConfig.CurvePreferences = curvePrefs
}
return tlsConfig, nil
}
func parseTLSVersion(version string) (uint16, error) {
if version == "" {
return 0, nil
}
v, ok := tlsVersions[version]
if !ok {
return 0, kverrors.New("unknown TLS version", "version", version)
}
return v, nil
}
func parseCipherSuites(ciphers []string) ([]uint16, error) {
if len(ciphers) == 0 {
return nil, nil
}
result := make([]uint16, 0, len(ciphers))
for _, name := range ciphers {
id, ok := tlsCipherSuites[strings.TrimSpace(name)]
if !ok {
return nil, kverrors.New("unknown cipher suite", "cipher", name)
}
result = append(result, id)
}
return result, nil
}
func parseCurvePreferences(curves []string) ([]tls.CurveID, error) {
if len(curves) == 0 {
return nil, nil
}
result := make([]tls.CurveID, 0, len(curves))
for _, name := range curves {
id, ok := tlsCurveIDs[strings.TrimSpace(name)]
if !ok {
return nil, kverrors.New("unknown curve", "curve", name)
}
result = append(result, id)
}
return result, nil
}

@ -0,0 +1,307 @@
package passthroughgateway
import (
"crypto/tls"
"testing"
"github.com/stretchr/testify/require"
)
func TestBuildTLSConfig(t *testing.T) {
tt := []struct {
desc string
cfg *TLSConfig
wantErr bool
errContains string
validate func(t *testing.T, cfg *tls.Config)
}{
{
desc: "default config with TLS 1.2 min version",
cfg: &TLSConfig{},
validate: func(t *testing.T, cfg *tls.Config) {
require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
},
},
{
desc: "custom min and maximum version",
cfg: &TLSConfig{
MinVersion: "VersionTLS12",
MaxVersion: "VersionTLS13",
},
validate: func(t *testing.T, cfg *tls.Config) {
require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion)
require.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion)
},
},
{
desc: "invalid min version",
cfg: &TLSConfig{
MinVersion: "TLS99",
},
wantErr: true,
errContains: "invalid minimum TLS version",
},
{
desc: "invalid max version",
cfg: &TLSConfig{
MaxVersion: "invalid",
},
wantErr: true,
errContains: "invalid maximum TLS version",
},
{
desc: "valid cipher suites",
cfg: &TLSConfig{
CipherSuites: []string{
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
},
},
validate: func(t *testing.T, cfg *tls.Config) {
require.Len(t, cfg.CipherSuites, 2)
require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256))
require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384))
},
},
{
desc: "invalid cipher suite",
cfg: &TLSConfig{
CipherSuites: []string{"INVALID_CIPHER"},
},
wantErr: true,
errContains: "invalid cipher suites",
},
{
desc: "valid curve preferences",
cfg: &TLSConfig{
CurvePrefs: []string{"X25519", "CurveP256"},
},
validate: func(t *testing.T, cfg *tls.Config) {
require.Len(t, cfg.CurvePreferences, 2)
require.Contains(t, cfg.CurvePreferences, tls.X25519)
require.Contains(t, cfg.CurvePreferences, tls.CurveP256)
},
},
{
desc: "invalid curve preference",
cfg: &TLSConfig{
CurvePrefs: []string{"INVALID_CURVE"},
},
wantErr: true,
errContains: "invalid curve preferences",
},
}
for _, tc := range tt {
t.Run(tc.desc, func(t *testing.T) {
cfg, err := buildTLSConfig(tc.cfg)
if tc.wantErr {
require.Error(t, err)
require.Contains(t, err.Error(), tc.errContains)
return
}
require.NoError(t, err)
if tc.validate != nil {
tc.validate(t, cfg)
}
})
}
}
func TestParseTLSVersion(t *testing.T) {
tt := []struct {
desc string
version string
want uint16
wantErr bool
}{
{
desc: "empty version returns zero",
version: "",
want: 0,
},
{
desc: "TLS 1.0",
version: "VersionTLS10",
want: tls.VersionTLS10,
},
{
desc: "TLS 1.2",
version: "VersionTLS12",
want: tls.VersionTLS12,
},
{
desc: "TLS 1.3",
version: "VersionTLS13",
want: tls.VersionTLS13,
},
{
desc: "unknown version",
version: "TLS99",
wantErr: true,
},
}
for _, tc := range tt {
t.Run(tc.desc, func(t *testing.T) {
got, err := parseTLSVersion(tc.version)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}
func TestParseClientAuthType(t *testing.T) {
tt := []struct {
desc string
authType string
want tls.ClientAuthType
wantErr bool
}{
{
desc: "empty defaults to RequireAndVerifyClientCert",
authType: "",
want: tls.RequireAndVerifyClientCert,
},
{
desc: "NoClientCert",
authType: "NoClientCert",
want: tls.NoClientCert,
},
{
desc: "RequestClientCert",
authType: "RequestClientCert",
want: tls.RequestClientCert,
},
{
desc: "RequireAnyClientCert",
authType: "RequireAnyClientCert",
want: tls.RequireAnyClientCert,
},
{
desc: "VerifyClientCertIfGiven",
authType: "VerifyClientCertIfGiven",
want: tls.VerifyClientCertIfGiven,
},
{
desc: "RequireAndVerifyClientCert",
authType: "RequireAndVerifyClientCert",
want: tls.RequireAndVerifyClientCert,
},
{
desc: "unknown auth type",
authType: "InvalidAuthType",
wantErr: true,
},
}
for _, tc := range tt {
t.Run(tc.desc, func(t *testing.T) {
got, err := parseClientAuthType(tc.authType)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}
func TestParseCipherSuites(t *testing.T) {
tt := []struct {
desc string
ciphers []string
want []uint16
wantErr bool
}{
{
desc: "empty returns nil",
ciphers: []string{},
want: nil,
},
{
desc: "single cipher",
ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"},
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
},
{
desc: "multiple ciphers",
ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256"},
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_AES_128_GCM_SHA256},
},
{
desc: "cipher with whitespace",
ciphers: []string{" TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 "},
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
},
{
desc: "unknown cipher",
ciphers: []string{"UNKNOWN_CIPHER"},
wantErr: true,
},
}
for _, tc := range tt {
t.Run(tc.desc, func(t *testing.T) {
got, err := parseCipherSuites(tc.ciphers)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}
func TestParseCurvePreferences(t *testing.T) {
tt := []struct {
desc string
curves []string
want []tls.CurveID
wantErr bool
}{
{
desc: "empty returns nil",
curves: []string{},
want: nil,
},
{
desc: "X25519",
curves: []string{"X25519"},
want: []tls.CurveID{tls.X25519},
},
{
desc: "multiple curves",
curves: []string{"X25519", "CurveP256", "CurveP384"},
want: []tls.CurveID{tls.X25519, tls.CurveP256, tls.CurveP384},
},
{
desc: "curve with whitespace",
curves: []string{" CurveP256 "},
want: []tls.CurveID{tls.CurveP256},
},
{
desc: "unknown curve",
curves: []string{"UNKNOWN_CURVE"},
wantErr: true,
},
}
for _, tc := range tt {
t.Run(tc.desc, func(t *testing.T) {
got, err := parseCurvePreferences(tc.curves)
if tc.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
require.Equal(t, tc.want, got)
})
}
}

@ -0,0 +1,26 @@
# Build the passthrough-gateway binary
FROM golang:1.26.3 as builder
WORKDIR /workspace
# Copy the Go Modules manifests
COPY api/ api/
COPY go.mod go.sum ./
# cache deps before building and copying source so that we don't need to re-download as much
# and so that source changes don't invalidate our downloaded layer
RUN go mod download
# Copy the go source
COPY cmd/passthrough-gateway/main.go cmd/passthrough-gateway/main.go
COPY internal/ internal/
# Build
RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -mod=readonly -o passthrough-gateway cmd/passthrough-gateway/main.go
# Use distroless as minimal base image to package the passthrough-gateway binary
# Refer to https://github.com/GoogleContainerTools/distroless for more details
FROM gcr.io/distroless/static:nonroot
WORKDIR /
COPY --from=builder /workspace/passthrough-gateway .
USER 65532:65532
ENTRYPOINT ["/passthrough-gateway"]
Loading…
Cancel
Save