mirror of https://github.com/grafana/loki
feat(operator): Add passthrough-gateway component (#20382)
Signed-off-by: Joao Marcal <jmarcal@redhat.com> Co-authored-by: Robert Jacob <rojacob@redhat.com>pull/21989/merge
parent
98d43a23c1
commit
bf06e45bcf
@ -0,0 +1,60 @@ |
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
"flag" |
||||
"os" |
||||
"os/signal" |
||||
"syscall" |
||||
|
||||
"github.com/ViaQ/logerr/v2/log" |
||||
"github.com/prometheus/client_golang/prometheus" |
||||
"github.com/prometheus/client_golang/prometheus/collectors" |
||||
|
||||
"github.com/grafana/loki/operator/internal/passthroughgateway" |
||||
) |
||||
|
||||
func main() { |
||||
logger := log.NewLogger("lokistack-gateway") |
||||
|
||||
cfg := &passthroughgateway.Config{} |
||||
f := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) |
||||
cfg.RegisterFlags(f) |
||||
|
||||
if err := f.Parse(os.Args[1:]); err != nil { |
||||
logger.Error(err, "failed to parse flags") |
||||
os.Exit(1) |
||||
} |
||||
|
||||
if err := cfg.Validate(); err != nil { |
||||
logger.Error(err, "invalid configuration") |
||||
os.Exit(1) |
||||
} |
||||
|
||||
logger.Info("starting gateway", |
||||
"listen-addr", cfg.ListenAddr, |
||||
"admin-addr", cfg.AdminAddr, |
||||
"loki-distributor-endpoint", cfg.Loki.DistributorEndpoint, |
||||
"loki-query-frontend-endpoint", cfg.Loki.QueryFrontendEndpoint, |
||||
) |
||||
|
||||
reg := prometheus.NewRegistry() |
||||
reg.MustRegister(collectors.NewGoCollector()) |
||||
reg.MustRegister(collectors.NewProcessCollector(collectors.ProcessCollectorOpts{})) |
||||
|
||||
server, err := passthroughgateway.NewServer(cfg, logger, reg) |
||||
if err != nil { |
||||
logger.Error(err, "failed to create server") |
||||
os.Exit(1) |
||||
} |
||||
|
||||
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) |
||||
defer cancel() |
||||
|
||||
if err := server.Run(ctx); err != nil { |
||||
logger.Error(err, "server error") |
||||
os.Exit(1) |
||||
} |
||||
|
||||
logger.Info("lokistack gateway stopped") |
||||
} |
||||
@ -0,0 +1,126 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"flag" |
||||
"net/url" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/ViaQ/logerr/v2/kverrors" |
||||
) |
||||
|
||||
type LokiConfig struct { |
||||
DistributorEndpoint string |
||||
QueryFrontendEndpoint string |
||||
Timeout time.Duration |
||||
CAFile string |
||||
CertFile string |
||||
KeyFile string |
||||
} |
||||
|
||||
type Config struct { |
||||
ListenAddr string |
||||
AdminAddr string |
||||
DefaultTenant string |
||||
Loki LokiConfig |
||||
// Server TLS configuration
|
||||
TLSCertFile string |
||||
TLSKeyFile string |
||||
// Client TLS configuration
|
||||
TLSClientCAFile string |
||||
TLSClientAuth string |
||||
// Generic TLS configuration
|
||||
TLSMinVersion string |
||||
TLSMaxVersion string |
||||
TLSCipherSuites StringSlice |
||||
TLSCurvePrefs StringSlice |
||||
} |
||||
|
||||
// StringSlice is a custom flag type for comma-separated strings.
|
||||
type StringSlice []string |
||||
|
||||
func (s *StringSlice) String() string { |
||||
return strings.Join(*s, ",") |
||||
} |
||||
|
||||
func (s *StringSlice) Set(value string) error { |
||||
if value == "" { |
||||
return nil |
||||
} |
||||
*s = strings.Split(value, ",") |
||||
return nil |
||||
} |
||||
|
||||
// TLSOptions returns the TLS configuration options.
|
||||
func (c *Config) TLSOptions() *TLSConfig { |
||||
return &TLSConfig{ |
||||
MinVersion: c.TLSMinVersion, |
||||
MaxVersion: c.TLSMaxVersion, |
||||
CipherSuites: c.TLSCipherSuites, |
||||
CurvePrefs: c.TLSCurvePrefs, |
||||
ClientAuth: c.TLSClientAuth, |
||||
} |
||||
} |
||||
|
||||
// RegisterFlags registers the configuration flags.
|
||||
func (c *Config) RegisterFlags(f *flag.FlagSet) { |
||||
f.StringVar(&c.ListenAddr, "listen-addr", ":8080", "Address for the server to listen on.") |
||||
f.StringVar(&c.AdminAddr, "admin-addr", ":9090", "Address for admin endpoints (metrics, health, readiness).") |
||||
f.StringVar(&c.DefaultTenant, "default-tenant", "", "Default tenant ID to use when X-Scope-OrgID header is not set. If empty, requests without X-Scope-OrgID are rejected.") |
||||
// Loki upstream configuration flags
|
||||
f.StringVar(&c.Loki.DistributorEndpoint, "loki-distributor-endpoint", "", "Upstream URL of the Loki distributor (write path).") |
||||
f.StringVar(&c.Loki.QueryFrontendEndpoint, "loki-query-frontend-endpoint", "", "Upstream URL of the Loki query frontend (read path).") |
||||
f.DurationVar(&c.Loki.Timeout, "loki-timeout", 60*time.Second, "Timeout for upstream Loki requests. Set to 0 for no timeout.") |
||||
f.StringVar(&c.Loki.CAFile, "loki-ca-file", "", "Path to the CA certificate for verifying the Loki server.") |
||||
f.StringVar(&c.Loki.CertFile, "loki-cert-file", "", "Path to the client certificate for Loki mTLS.") |
||||
f.StringVar(&c.Loki.KeyFile, "loki-key-file", "", "Path to the client private key for Loki mTLS.") |
||||
// Server TLS configuration flags
|
||||
f.StringVar(&c.TLSCertFile, "tls-cert-file", "", "Path to the server TLS certificate file.") |
||||
f.StringVar(&c.TLSKeyFile, "tls-key-file", "", "Path to the server TLS private key file.") |
||||
// Client TLS configuration flags
|
||||
f.StringVar(&c.TLSClientCAFile, "tls-client-ca-file", "", "Path to the CA certificate for verifying client certificates.") |
||||
f.StringVar(&c.TLSClientAuth, "tls-client-auth", "NoClientCert", "Client certificate auth mode (NoClientCert, RequestClientCert, RequireAnyClientCert, VerifyClientCertIfGiven, RequireAndVerifyClientCert).") |
||||
// Generic TLS configuration flags
|
||||
f.StringVar(&c.TLSMinVersion, "tls-min-version", "VersionTLS12", "Minimum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13).") |
||||
f.StringVar(&c.TLSMaxVersion, "tls-max-version", "", "Maximum TLS version (VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13). Empty means no maximum.") |
||||
f.Var(&c.TLSCipherSuites, "tls-cipher-suites", "Comma-separated list of TLS cipher suites.") |
||||
f.Var(&c.TLSCurvePrefs, "tls-curve-preferences", "Comma-separated list of curve preferences (X25519, CurveP256, CurveP384, CurveP521).") |
||||
} |
||||
|
||||
// Validate checks that all required configuration is provided and valid.
|
||||
func (c *Config) Validate() error { |
||||
if c.Loki.DistributorEndpoint == "" { |
||||
return kverrors.New("-loki-distributor-endpoint is required") |
||||
} |
||||
|
||||
u, err := url.Parse(c.Loki.DistributorEndpoint) |
||||
if err != nil { |
||||
return kverrors.New("-loki-distributor-endpoint is not a valid URL", "endpoint", c.Loki.DistributorEndpoint, "err", err) |
||||
} |
||||
if u.Host == "" { |
||||
return kverrors.New("-loki-distributor-endpoint is missing a host", "endpoint", c.Loki.DistributorEndpoint) |
||||
} |
||||
|
||||
if c.Loki.QueryFrontendEndpoint == "" { |
||||
return kverrors.New("-loki-query-frontend-endpoint is required") |
||||
} |
||||
|
||||
u, err = url.Parse(c.Loki.QueryFrontendEndpoint) |
||||
if err != nil { |
||||
return kverrors.New("-loki-query-frontend-endpoint is not a valid URL", "endpoint", c.Loki.QueryFrontendEndpoint, "err", err) |
||||
} |
||||
if u.Host == "" { |
||||
return kverrors.New("-loki-query-frontend-endpoint is missing a host", "endpoint", c.Loki.QueryFrontendEndpoint) |
||||
} |
||||
|
||||
if c.DefaultTenant != "" && strings.ContainsAny(c.DefaultTenant, "\n") { |
||||
return kverrors.New("-default-tenant must not contain newline characters") |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// TLSEnabled returns true if TLS certificates are configured.
|
||||
func (c *Config) TLSEnabled() bool { |
||||
return c.TLSCertFile != "" && c.TLSKeyFile != "" |
||||
} |
||||
@ -0,0 +1,43 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"github.com/prometheus/client_golang/prometheus" |
||||
"github.com/prometheus/client_golang/prometheus/promauto" |
||||
) |
||||
|
||||
// metrics holds Prometheus metrics for the LokiStack gateway.
|
||||
type metrics struct { |
||||
RequestsTotal *prometheus.CounterVec |
||||
RequestDuration *prometheus.HistogramVec |
||||
RequestsInFlight *prometheus.GaugeVec |
||||
} |
||||
|
||||
// newMetrics creates and registers Prometheus metrics for the LokiStack gateway.
|
||||
func newMetrics(reg prometheus.Registerer) *metrics { |
||||
factory := promauto.With(reg) |
||||
|
||||
return &metrics{ |
||||
RequestsTotal: factory.NewCounterVec( |
||||
prometheus.CounterOpts{ |
||||
Name: "lokistack_gateway_requests_total", |
||||
Help: "Total number of requests processed by the LokiStack gateway.", |
||||
}, |
||||
[]string{"method", "route", "status_code"}, |
||||
), |
||||
RequestDuration: factory.NewHistogramVec( |
||||
prometheus.HistogramOpts{ |
||||
Name: "lokistack_gateway_request_duration_seconds", |
||||
Help: "Duration of requests processed by the LokiStack gateway.", |
||||
Buckets: []float64{.1, 1, 5, 9, 15, 30, 60, 120, 300}, |
||||
}, |
||||
[]string{"method", "route"}, |
||||
), |
||||
RequestsInFlight: factory.NewGaugeVec( |
||||
prometheus.GaugeOpts{ |
||||
Name: "lokistack_gateway_requests_in_flight", |
||||
Help: "Current number of requests being processed by the LokiStack gateway.", |
||||
}, |
||||
[]string{"route"}, |
||||
), |
||||
} |
||||
} |
||||
@ -0,0 +1,57 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/prometheus/client_golang/prometheus" |
||||
"github.com/prometheus/client_golang/prometheus/testutil" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestMetricsRequestsTotal(t *testing.T) { |
||||
registry := prometheus.NewPedanticRegistry() |
||||
metrics := newMetrics(registry) |
||||
|
||||
metrics.RequestsTotal.WithLabelValues("GET", "read", "200").Inc() |
||||
metrics.RequestsTotal.WithLabelValues("POST", "write", "201").Inc() |
||||
|
||||
expected := ` |
||||
# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway. |
||||
# TYPE lokistack_gateway_requests_total counter |
||||
lokistack_gateway_requests_total{method="GET",route="read",status_code="200"} 1 |
||||
lokistack_gateway_requests_total{method="POST",route="write",status_code="201"} 1 |
||||
` |
||||
|
||||
err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expected)) |
||||
require.NoError(t, err) |
||||
} |
||||
|
||||
func TestMetricsRequestDuration(t *testing.T) { |
||||
registry := prometheus.NewPedanticRegistry() |
||||
metrics := newMetrics(registry) |
||||
|
||||
metrics.RequestDuration.WithLabelValues("GET", "/api/logs").Observe(0.5) |
||||
|
||||
count := testutil.CollectAndCount(metrics.RequestDuration) |
||||
require.Equal(t, 1, count) |
||||
} |
||||
|
||||
func TestMetricsRequestsInFlight(t *testing.T) { |
||||
registry := prometheus.NewPedanticRegistry() |
||||
metrics := newMetrics(registry) |
||||
|
||||
metrics.RequestsInFlight.WithLabelValues("read").Inc() |
||||
metrics.RequestsInFlight.WithLabelValues("write").Inc() |
||||
metrics.RequestsInFlight.WithLabelValues("read").Dec() |
||||
|
||||
expected := ` |
||||
# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway. |
||||
# TYPE lokistack_gateway_requests_in_flight gauge |
||||
lokistack_gateway_requests_in_flight{route="read"} 0 |
||||
lokistack_gateway_requests_in_flight{route="write"} 1 |
||||
` |
||||
|
||||
err := testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expected)) |
||||
require.NoError(t, err) |
||||
} |
||||
@ -0,0 +1,154 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"net/http" |
||||
"net/http/httputil" |
||||
"net/url" |
||||
"slices" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/go-logr/logr" |
||||
) |
||||
|
||||
var lokiWritePaths = []string{ |
||||
"/loki/api/v1/push", |
||||
"/api/prom/push", |
||||
"/otlp/v1/logs", |
||||
} |
||||
|
||||
// lokiRouter directs requests to the appropriate Loki upstream (distributor or query-frontend).
|
||||
type lokiRouter struct { |
||||
writeProxy *httputil.ReverseProxy |
||||
readProxy *httputil.ReverseProxy |
||||
logger logr.Logger |
||||
defaultTenant string |
||||
} |
||||
|
||||
type responseWriter struct { |
||||
http.ResponseWriter |
||||
statusCode int |
||||
} |
||||
|
||||
func (rw *responseWriter) WriteHeader(code int) { |
||||
rw.statusCode = code |
||||
rw.ResponseWriter.WriteHeader(code) |
||||
} |
||||
|
||||
func (rw *responseWriter) Flush() { |
||||
if f, ok := rw.ResponseWriter.(http.Flusher); ok { |
||||
f.Flush() |
||||
} |
||||
} |
||||
|
||||
func (rw *responseWriter) Unwrap() http.ResponseWriter { |
||||
return rw.ResponseWriter |
||||
} |
||||
|
||||
// NewLokiRouter creates a new router that directs requests to the appropriate upstream.
|
||||
func NewLokiRouter(cfg *Config, logger logr.Logger) (*lokiRouter, error) { |
||||
transport, err := newTransport(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
writeProxy, err := newReverseProxy(cfg.Loki.DistributorEndpoint, transport, logger) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
readProxy, err := newReverseProxy(cfg.Loki.QueryFrontendEndpoint, transport, logger) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &lokiRouter{ |
||||
writeProxy: writeProxy, |
||||
readProxy: readProxy, |
||||
logger: logger, |
||||
defaultTenant: cfg.DefaultTenant, |
||||
}, nil |
||||
} |
||||
|
||||
func newTransport(cfg *Config) (*http.Transport, error) { |
||||
transport := http.DefaultTransport.(*http.Transport).Clone() |
||||
transport.Proxy = nil |
||||
transport.ResponseHeaderTimeout = cfg.Loki.Timeout |
||||
|
||||
tlsConfig, err := BuildUpstreamTLSConfig(cfg.TLSOptions(), cfg.Loki.CAFile, cfg.Loki.CertFile, cfg.Loki.KeyFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
transport.TLSClientConfig = tlsConfig |
||||
|
||||
return transport, nil |
||||
} |
||||
|
||||
func newReverseProxy(upstreamEndpoint string, transport *http.Transport, logger logr.Logger) (*httputil.ReverseProxy, error) { |
||||
target, err := url.Parse(upstreamEndpoint) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
proxy := &httputil.ReverseProxy{ |
||||
Rewrite: func(pr *httputil.ProxyRequest) { |
||||
pr.SetURL(target) |
||||
pr.Out.Host = target.Host |
||||
}, |
||||
Transport: transport, |
||||
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { |
||||
logger.Error(err, "proxy error", "method", r.Method, "path", r.URL.Path) |
||||
w.WriteHeader(http.StatusBadGateway) |
||||
}, |
||||
} |
||||
|
||||
return proxy, nil |
||||
} |
||||
|
||||
func isWritePath(path string) bool { |
||||
return slices.ContainsFunc(lokiWritePaths, func(writePath string) bool { |
||||
return strings.HasPrefix(path, writePath) |
||||
}) |
||||
} |
||||
|
||||
func (r *lokiRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
||||
if req.Header.Get("X-Scope-OrgID") == "" { |
||||
if r.defaultTenant == "" { |
||||
r.logger.Error(nil, "missing required header", "header", "X-Scope-OrgID", "path", req.URL.Path, "method", req.Method) |
||||
w.WriteHeader(http.StatusBadRequest) |
||||
_, _ = w.Write([]byte("Bad request: X-Scope-OrgID header is required")) |
||||
return |
||||
} |
||||
req.Header.Set("X-Scope-OrgID", r.defaultTenant) |
||||
} |
||||
|
||||
if isWritePath(req.URL.Path) { |
||||
r.writeProxy.ServeHTTP(w, req) |
||||
return |
||||
} |
||||
r.readProxy.ServeHTTP(w, req) |
||||
} |
||||
|
||||
// instrumentedHandler wraps an http.Handler with metrics instrumentation.
|
||||
func instrumentedHandler(handler http.Handler, metrics *metrics) http.Handler { |
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
route := "read" |
||||
if isWritePath(r.URL.Path) { |
||||
route = "write" |
||||
} |
||||
|
||||
inFlight := metrics.RequestsInFlight.WithLabelValues(route) |
||||
inFlight.Inc() |
||||
defer inFlight.Dec() |
||||
|
||||
start := time.Now() |
||||
wrapped := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK} |
||||
|
||||
handler.ServeHTTP(wrapped, r) |
||||
|
||||
duration := time.Since(start).Seconds() |
||||
metrics.RequestDuration.WithLabelValues(r.Method, route).Observe(duration) |
||||
metrics.RequestsTotal.WithLabelValues(r.Method, route, strconv.Itoa(wrapped.statusCode)).Inc() |
||||
}) |
||||
} |
||||
@ -0,0 +1,117 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"fmt" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/prometheus/client_golang/prometheus" |
||||
"github.com/prometheus/client_golang/prometheus/testutil" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func BenchmarkIsWritePath(b *testing.B) { |
||||
benchmarks := []struct { |
||||
name string |
||||
path string |
||||
}{ |
||||
{"write_loki_push", "/loki/api/v1/push"}, |
||||
{"write_prom_push", "/api/prom/push"}, |
||||
{"write_otlp", "/otlp/v1/logs"}, |
||||
{"write_with_suffix", "/loki/api/v1/push?foo=bar"}, |
||||
{"read_query", "/loki/api/v1/query"}, |
||||
{"read_labels", "/loki/api/v1/labels"}, |
||||
{"read_series", "/loki/api/v1/series"}, |
||||
{"read_tail", "/loki/api/v1/tail"}, |
||||
{"read_long_path", "/loki/api/v1/query_range?query=rate({app='test'}[5m])&start=1234567890&end=1234567899"}, |
||||
} |
||||
|
||||
for _, bm := range benchmarks { |
||||
b.Run(bm.name, func(b *testing.B) { |
||||
b.ReportAllocs() |
||||
for b.Loop() { |
||||
_ = isWritePath(bm.path) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestInstrumentedHandler(t *testing.T) { |
||||
tt := []struct { |
||||
name string |
||||
method string |
||||
path string |
||||
route string |
||||
handlerStatus int |
||||
}{ |
||||
{ |
||||
name: "read request with GET", |
||||
method: http.MethodGet, |
||||
path: "/loki/api/v1/query", |
||||
route: "read", |
||||
handlerStatus: http.StatusOK, |
||||
}, |
||||
{ |
||||
name: "write request with POST", |
||||
method: http.MethodPost, |
||||
path: "/loki/api/v1/push", |
||||
route: "write", |
||||
handlerStatus: http.StatusNoContent, |
||||
}, |
||||
{ |
||||
name: "write request otlp", |
||||
method: http.MethodPost, |
||||
path: "/otlp/v1/logs", |
||||
route: "write", |
||||
handlerStatus: http.StatusOK, |
||||
}, |
||||
{ |
||||
name: "read request with error status", |
||||
method: http.MethodGet, |
||||
path: "/loki/api/v1/labels", |
||||
route: "read", |
||||
handlerStatus: http.StatusInternalServerError, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.name, func(t *testing.T) { |
||||
registry := prometheus.NewPedanticRegistry() |
||||
metrics := newMetrics(registry) |
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
w.WriteHeader(tc.handlerStatus) |
||||
}) |
||||
|
||||
instrumented := instrumentedHandler(handler, metrics) |
||||
|
||||
req := httptest.NewRequest(tc.method, tc.path, nil) |
||||
rec := httptest.NewRecorder() |
||||
|
||||
instrumented.ServeHTTP(rec, req) |
||||
|
||||
require.Equal(t, tc.handlerStatus, rec.Code) |
||||
|
||||
expectedTotal := fmt.Sprintf(` |
||||
# HELP lokistack_gateway_requests_total Total number of requests processed by the LokiStack gateway. |
||||
# TYPE lokistack_gateway_requests_total counter |
||||
lokistack_gateway_requests_total{method="%s",route="%s",status_code="%d"} 1 |
||||
`, tc.method, tc.route, tc.handlerStatus) |
||||
err := testutil.CollectAndCompare(metrics.RequestsTotal, strings.NewReader(expectedTotal)) |
||||
require.NoError(t, err) |
||||
|
||||
durationCount := testutil.CollectAndCount(metrics.RequestDuration) |
||||
require.Equal(t, 1, durationCount) |
||||
|
||||
expectedInFlight := fmt.Sprintf(` |
||||
# HELP lokistack_gateway_requests_in_flight Current number of requests being processed by the LokiStack gateway. |
||||
# TYPE lokistack_gateway_requests_in_flight gauge |
||||
lokistack_gateway_requests_in_flight{route="%s"} 0 |
||||
`, tc.route) |
||||
err = testutil.CollectAndCompare(metrics.RequestsInFlight, strings.NewReader(expectedInFlight)) |
||||
require.NoError(t, err) |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,177 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"net/http" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"github.com/go-logr/logr" |
||||
"github.com/prometheus/client_golang/prometheus" |
||||
"github.com/prometheus/client_golang/prometheus/promhttp" |
||||
) |
||||
|
||||
type Server struct { |
||||
config *Config |
||||
logger logr.Logger |
||||
metrics *metrics |
||||
proxyServer *http.Server |
||||
adminServer *http.Server |
||||
shuttingDown *atomic.Bool |
||||
} |
||||
|
||||
func NewServer(cfg *Config, logger logr.Logger, reg prometheus.Registerer) (*Server, error) { |
||||
metrics := newMetrics(reg) |
||||
|
||||
router, err := NewLokiRouter(cfg, logger) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
proxyHandler := instrumentedHandler(router, metrics) |
||||
|
||||
var proxyWriteTimeout time.Duration |
||||
if cfg.Loki.Timeout > 0 { |
||||
proxyWriteTimeout = cfg.Loki.Timeout + 10*time.Second |
||||
} |
||||
|
||||
proxyServer := &http.Server{ |
||||
Addr: cfg.ListenAddr, |
||||
Handler: proxyHandler, |
||||
ReadTimeout: 30 * time.Second, |
||||
WriteTimeout: proxyWriteTimeout, |
||||
IdleTimeout: 120 * time.Second, |
||||
} |
||||
|
||||
adminServer := &http.Server{ |
||||
Addr: cfg.AdminAddr, |
||||
ReadTimeout: 10 * time.Second, |
||||
} |
||||
|
||||
if cfg.TLSEnabled() { |
||||
tlsConfig, err := BuildServerTLSConfigWithClientAuth(cfg.TLSOptions(), cfg.TLSClientCAFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
proxyServer.TLSConfig = tlsConfig |
||||
|
||||
adminTLSConfig, err := BuildServerTLSConfig(cfg.TLSOptions(), cfg.TLSClientCAFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
adminServer.TLSConfig = adminTLSConfig |
||||
} |
||||
|
||||
adminMux := http.NewServeMux() |
||||
adminMux.Handle("/metrics", promhttp.HandlerFor( |
||||
reg.(prometheus.Gatherer), |
||||
promhttp.HandlerOpts{Registry: reg}, |
||||
)) |
||||
adminMux.HandleFunc("/live", func(w http.ResponseWriter, r *http.Request) { |
||||
w.WriteHeader(http.StatusOK) |
||||
_, _ = w.Write([]byte("ok")) |
||||
}) |
||||
server := &Server{ |
||||
config: cfg, |
||||
logger: logger, |
||||
metrics: metrics, |
||||
shuttingDown: &atomic.Bool{}, |
||||
} |
||||
|
||||
adminMux.HandleFunc("/ready", func(w http.ResponseWriter, r *http.Request) { |
||||
if server.shuttingDown.Load() { |
||||
w.WriteHeader(http.StatusServiceUnavailable) |
||||
_, _ = w.Write([]byte("shutting down")) |
||||
return |
||||
} |
||||
w.WriteHeader(http.StatusOK) |
||||
_, _ = w.Write([]byte("ok")) |
||||
}) |
||||
|
||||
adminServer.Handler = adminMux |
||||
adminServer.WriteTimeout = 10 * time.Second |
||||
|
||||
server.proxyServer = proxyServer |
||||
server.adminServer = adminServer |
||||
|
||||
return server, nil |
||||
} |
||||
|
||||
func (s *Server) Run(ctx context.Context) error { |
||||
var wg sync.WaitGroup |
||||
errChan := make(chan error, 2) |
||||
|
||||
wg.Go(func() { |
||||
var err error |
||||
if s.config.TLSEnabled() { |
||||
s.logger.Info("starting HTTPS admin server", "addr", s.config.AdminAddr) |
||||
err = s.adminServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile) |
||||
} else { |
||||
s.logger.Info("starting HTTP admin server", "addr", s.config.AdminAddr) |
||||
err = s.adminServer.ListenAndServe() |
||||
} |
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) { |
||||
errChan <- err |
||||
} |
||||
}) |
||||
|
||||
wg.Go(func() { |
||||
var err error |
||||
if s.config.TLSEnabled() { |
||||
s.logger.Info("starting mTLS proxy server", "addr", s.config.ListenAddr) |
||||
err = s.proxyServer.ListenAndServeTLS(s.config.TLSCertFile, s.config.TLSKeyFile) |
||||
} else { |
||||
s.logger.Info("starting HTTP proxy server", "addr", s.config.ListenAddr) |
||||
err = s.proxyServer.ListenAndServe() |
||||
} |
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) { |
||||
errChan <- err |
||||
} |
||||
}) |
||||
|
||||
var runErr error |
||||
select { |
||||
case <-ctx.Done(): |
||||
s.logger.Info("shutting down servers") |
||||
runErr = s.Shutdown() |
||||
case err := <-errChan: |
||||
runErr = err |
||||
_ = s.Shutdown() |
||||
} |
||||
|
||||
wg.Wait() |
||||
return runErr |
||||
} |
||||
|
||||
func (s *Server) Shutdown() error { |
||||
s.shuttingDown.Store(true) |
||||
|
||||
var ( |
||||
wg sync.WaitGroup |
||||
proxyErr error |
||||
adminErr error |
||||
) |
||||
|
||||
wg.Go(func() { |
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
||||
defer cancel() |
||||
if err := s.proxyServer.Shutdown(ctx); err != nil { |
||||
s.logger.Error(err, "error shutting down proxy server") |
||||
proxyErr = err |
||||
} |
||||
}) |
||||
|
||||
wg.Go(func() { |
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) |
||||
defer cancel() |
||||
if err := s.adminServer.Shutdown(ctx); err != nil { |
||||
s.logger.Error(err, "error shutting down admin server") |
||||
adminErr = err |
||||
} |
||||
}) |
||||
|
||||
wg.Wait() |
||||
return errors.Join(proxyErr, adminErr) |
||||
} |
||||
@ -0,0 +1,231 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"errors" |
||||
"fmt" |
||||
"os" |
||||
"strings" |
||||
|
||||
"github.com/ViaQ/logerr/v2/kverrors" |
||||
) |
||||
|
||||
var errUnknownClientAuthType = errors.New("unknown client auth type") |
||||
|
||||
type TLSConfig struct { |
||||
MinVersion string |
||||
MaxVersion string |
||||
CipherSuites []string |
||||
CurvePrefs []string |
||||
ClientAuth string |
||||
} |
||||
|
||||
var tlsClientAuthTypes = map[string]tls.ClientAuthType{ |
||||
"NoClientCert": tls.NoClientCert, |
||||
"RequestClientCert": tls.RequestClientCert, |
||||
"RequireAnyClientCert": tls.RequireAnyClientCert, |
||||
"VerifyClientCertIfGiven": tls.VerifyClientCertIfGiven, |
||||
"RequireAndVerifyClientCert": tls.RequireAndVerifyClientCert, |
||||
} |
||||
|
||||
var tlsVersions = map[string]uint16{ |
||||
"VersionTLS10": tls.VersionTLS10, |
||||
"VersionTLS11": tls.VersionTLS11, |
||||
"VersionTLS12": tls.VersionTLS12, |
||||
"VersionTLS13": tls.VersionTLS13, |
||||
} |
||||
|
||||
var tlsCipherSuites = map[string]uint16{ |
||||
// TLS 1.2 cipher suites
|
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, |
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, |
||||
"TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256": tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, |
||||
"TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384": tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, |
||||
"TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, |
||||
"TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256": tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, |
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, |
||||
"TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, |
||||
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA": tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, |
||||
"TLS_RSA_WITH_AES_128_GCM_SHA256": tls.TLS_RSA_WITH_AES_128_GCM_SHA256, |
||||
"TLS_RSA_WITH_AES_256_GCM_SHA384": tls.TLS_RSA_WITH_AES_256_GCM_SHA384, |
||||
"TLS_RSA_WITH_AES_128_CBC_SHA256": tls.TLS_RSA_WITH_AES_128_CBC_SHA256, |
||||
// TLS 1.3 cipher suites (always enabled when TLS 1.3 is used)
|
||||
"TLS_AES_128_GCM_SHA256": tls.TLS_AES_128_GCM_SHA256, |
||||
"TLS_AES_256_GCM_SHA384": tls.TLS_AES_256_GCM_SHA384, |
||||
"TLS_CHACHA20_POLY1305_SHA256": tls.TLS_CHACHA20_POLY1305_SHA256, |
||||
} |
||||
|
||||
var tlsCurveIDs = map[string]tls.CurveID{ |
||||
"X25519": tls.X25519, |
||||
"CurveP256": tls.CurveP256, |
||||
"CurveP384": tls.CurveP384, |
||||
"CurveP521": tls.CurveP521, |
||||
} |
||||
|
||||
// BuildServerTLSConfig creates a TLS config for the server with optional mTLS.
|
||||
func BuildServerTLSConfig(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) { |
||||
tlsConfig, err := buildTLSConfig(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if clientCAFile != "" { |
||||
caCert, err := os.ReadFile(clientCAFile) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to read client CA file: %w", err) |
||||
} |
||||
|
||||
caCertPool := x509.NewCertPool() |
||||
if !caCertPool.AppendCertsFromPEM(caCert) { |
||||
return nil, kverrors.New("failed to append client CA certs") |
||||
} |
||||
|
||||
tlsConfig.ClientCAs = caCertPool |
||||
} |
||||
|
||||
return tlsConfig, nil |
||||
} |
||||
|
||||
// BuildServerTLSConfigWithClientAuth creates a TLS config for the server with optional client authentication.
|
||||
func BuildServerTLSConfigWithClientAuth(cfg *TLSConfig, clientCAFile string) (*tls.Config, error) { |
||||
tlsConfig, err := BuildServerTLSConfig(cfg, clientCAFile) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
if clientCAFile != "" { |
||||
clientAuth, err := parseClientAuthType(cfg.ClientAuth) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
tlsConfig.ClientAuth = clientAuth |
||||
} |
||||
|
||||
return tlsConfig, nil |
||||
} |
||||
|
||||
// BuildUpstreamTLSConfig creates a TLS config for upstream connections with optional client certificate.
|
||||
func BuildUpstreamTLSConfig(cfg *TLSConfig, upstreamCAFile, upstreamCertFile, upstreamKeyFile string) (*tls.Config, error) { |
||||
tlsConfig, err := buildTLSConfig(cfg) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
// Load upstream CA for server verification
|
||||
if upstreamCAFile != "" { |
||||
caCert, err := os.ReadFile(upstreamCAFile) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to read upstream CA file: %w", err) |
||||
} |
||||
|
||||
caCertPool := x509.NewCertPool() |
||||
if !caCertPool.AppendCertsFromPEM(caCert) { |
||||
return nil, kverrors.New("failed to append upstream CA certs") |
||||
} |
||||
tlsConfig.RootCAs = caCertPool |
||||
} |
||||
|
||||
// Load client certificate for upstream mTLS
|
||||
if upstreamCertFile != "" && upstreamKeyFile != "" { |
||||
cert, err := tls.LoadX509KeyPair(upstreamCertFile, upstreamKeyFile) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to load upstream client certificate: %w", err) |
||||
} |
||||
tlsConfig.Certificates = []tls.Certificate{cert} |
||||
} |
||||
|
||||
return tlsConfig, nil |
||||
} |
||||
|
||||
func parseClientAuthType(authType string) (tls.ClientAuthType, error) { |
||||
if authType == "" { |
||||
return tls.RequireAndVerifyClientCert, nil |
||||
} |
||||
auth, ok := tlsClientAuthTypes[authType] |
||||
if !ok { |
||||
return 0, fmt.Errorf("%w: %s", errUnknownClientAuthType, authType) |
||||
} |
||||
return auth, nil |
||||
} |
||||
|
||||
func buildTLSConfig(cfg *TLSConfig) (*tls.Config, error) { |
||||
tlsConfig := &tls.Config{ |
||||
MinVersion: tls.VersionTLS12, |
||||
} |
||||
|
||||
if cfg.MinVersion != "" { |
||||
minVersion, err := parseTLSVersion(cfg.MinVersion) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid minimum TLS version: %w", err) |
||||
} |
||||
tlsConfig.MinVersion = minVersion |
||||
} |
||||
|
||||
if cfg.MaxVersion != "" { |
||||
maxVersion, err := parseTLSVersion(cfg.MaxVersion) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid maximum TLS version: %w", err) |
||||
} |
||||
tlsConfig.MaxVersion = maxVersion |
||||
} |
||||
|
||||
if len(cfg.CipherSuites) > 0 { |
||||
cipherSuites, err := parseCipherSuites(cfg.CipherSuites) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid cipher suites: %w", err) |
||||
} |
||||
tlsConfig.CipherSuites = cipherSuites |
||||
} |
||||
|
||||
if len(cfg.CurvePrefs) > 0 { |
||||
curvePrefs, err := parseCurvePreferences(cfg.CurvePrefs) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("invalid curve preferences: %w", err) |
||||
} |
||||
tlsConfig.CurvePreferences = curvePrefs |
||||
} |
||||
|
||||
return tlsConfig, nil |
||||
} |
||||
|
||||
func parseTLSVersion(version string) (uint16, error) { |
||||
if version == "" { |
||||
return 0, nil |
||||
} |
||||
v, ok := tlsVersions[version] |
||||
if !ok { |
||||
return 0, kverrors.New("unknown TLS version", "version", version) |
||||
} |
||||
return v, nil |
||||
} |
||||
|
||||
func parseCipherSuites(ciphers []string) ([]uint16, error) { |
||||
if len(ciphers) == 0 { |
||||
return nil, nil |
||||
} |
||||
result := make([]uint16, 0, len(ciphers)) |
||||
for _, name := range ciphers { |
||||
id, ok := tlsCipherSuites[strings.TrimSpace(name)] |
||||
if !ok { |
||||
return nil, kverrors.New("unknown cipher suite", "cipher", name) |
||||
} |
||||
result = append(result, id) |
||||
} |
||||
return result, nil |
||||
} |
||||
|
||||
func parseCurvePreferences(curves []string) ([]tls.CurveID, error) { |
||||
if len(curves) == 0 { |
||||
return nil, nil |
||||
} |
||||
result := make([]tls.CurveID, 0, len(curves)) |
||||
for _, name := range curves { |
||||
id, ok := tlsCurveIDs[strings.TrimSpace(name)] |
||||
if !ok { |
||||
return nil, kverrors.New("unknown curve", "curve", name) |
||||
} |
||||
result = append(result, id) |
||||
} |
||||
return result, nil |
||||
} |
||||
@ -0,0 +1,307 @@ |
||||
package passthroughgateway |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestBuildTLSConfig(t *testing.T) { |
||||
tt := []struct { |
||||
desc string |
||||
cfg *TLSConfig |
||||
wantErr bool |
||||
errContains string |
||||
validate func(t *testing.T, cfg *tls.Config) |
||||
}{ |
||||
{ |
||||
desc: "default config with TLS 1.2 min version", |
||||
cfg: &TLSConfig{}, |
||||
validate: func(t *testing.T, cfg *tls.Config) { |
||||
require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion) |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "custom min and maximum version", |
||||
cfg: &TLSConfig{ |
||||
MinVersion: "VersionTLS12", |
||||
MaxVersion: "VersionTLS13", |
||||
}, |
||||
validate: func(t *testing.T, cfg *tls.Config) { |
||||
require.Equal(t, uint16(tls.VersionTLS12), cfg.MinVersion) |
||||
require.Equal(t, uint16(tls.VersionTLS13), cfg.MaxVersion) |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "invalid min version", |
||||
cfg: &TLSConfig{ |
||||
MinVersion: "TLS99", |
||||
}, |
||||
wantErr: true, |
||||
errContains: "invalid minimum TLS version", |
||||
}, |
||||
{ |
||||
desc: "invalid max version", |
||||
cfg: &TLSConfig{ |
||||
MaxVersion: "invalid", |
||||
}, |
||||
wantErr: true, |
||||
errContains: "invalid maximum TLS version", |
||||
}, |
||||
{ |
||||
desc: "valid cipher suites", |
||||
cfg: &TLSConfig{ |
||||
CipherSuites: []string{ |
||||
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", |
||||
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", |
||||
}, |
||||
}, |
||||
validate: func(t *testing.T, cfg *tls.Config) { |
||||
require.Len(t, cfg.CipherSuites, 2) |
||||
require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)) |
||||
require.Contains(t, cfg.CipherSuites, uint16(tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384)) |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "invalid cipher suite", |
||||
cfg: &TLSConfig{ |
||||
CipherSuites: []string{"INVALID_CIPHER"}, |
||||
}, |
||||
wantErr: true, |
||||
errContains: "invalid cipher suites", |
||||
}, |
||||
{ |
||||
desc: "valid curve preferences", |
||||
cfg: &TLSConfig{ |
||||
CurvePrefs: []string{"X25519", "CurveP256"}, |
||||
}, |
||||
validate: func(t *testing.T, cfg *tls.Config) { |
||||
require.Len(t, cfg.CurvePreferences, 2) |
||||
require.Contains(t, cfg.CurvePreferences, tls.X25519) |
||||
require.Contains(t, cfg.CurvePreferences, tls.CurveP256) |
||||
}, |
||||
}, |
||||
{ |
||||
desc: "invalid curve preference", |
||||
cfg: &TLSConfig{ |
||||
CurvePrefs: []string{"INVALID_CURVE"}, |
||||
}, |
||||
wantErr: true, |
||||
errContains: "invalid curve preferences", |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.desc, func(t *testing.T) { |
||||
cfg, err := buildTLSConfig(tc.cfg) |
||||
if tc.wantErr { |
||||
require.Error(t, err) |
||||
require.Contains(t, err.Error(), tc.errContains) |
||||
return |
||||
} |
||||
require.NoError(t, err) |
||||
if tc.validate != nil { |
||||
tc.validate(t, cfg) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestParseTLSVersion(t *testing.T) { |
||||
tt := []struct { |
||||
desc string |
||||
version string |
||||
want uint16 |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
desc: "empty version returns zero", |
||||
version: "", |
||||
want: 0, |
||||
}, |
||||
{ |
||||
desc: "TLS 1.0", |
||||
version: "VersionTLS10", |
||||
want: tls.VersionTLS10, |
||||
}, |
||||
{ |
||||
desc: "TLS 1.2", |
||||
version: "VersionTLS12", |
||||
want: tls.VersionTLS12, |
||||
}, |
||||
{ |
||||
desc: "TLS 1.3", |
||||
version: "VersionTLS13", |
||||
want: tls.VersionTLS13, |
||||
}, |
||||
{ |
||||
desc: "unknown version", |
||||
version: "TLS99", |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.desc, func(t *testing.T) { |
||||
got, err := parseTLSVersion(tc.version) |
||||
if tc.wantErr { |
||||
require.Error(t, err) |
||||
return |
||||
} |
||||
require.NoError(t, err) |
||||
require.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestParseClientAuthType(t *testing.T) { |
||||
tt := []struct { |
||||
desc string |
||||
authType string |
||||
want tls.ClientAuthType |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
desc: "empty defaults to RequireAndVerifyClientCert", |
||||
authType: "", |
||||
want: tls.RequireAndVerifyClientCert, |
||||
}, |
||||
{ |
||||
desc: "NoClientCert", |
||||
authType: "NoClientCert", |
||||
want: tls.NoClientCert, |
||||
}, |
||||
{ |
||||
desc: "RequestClientCert", |
||||
authType: "RequestClientCert", |
||||
want: tls.RequestClientCert, |
||||
}, |
||||
{ |
||||
desc: "RequireAnyClientCert", |
||||
authType: "RequireAnyClientCert", |
||||
want: tls.RequireAnyClientCert, |
||||
}, |
||||
{ |
||||
desc: "VerifyClientCertIfGiven", |
||||
authType: "VerifyClientCertIfGiven", |
||||
want: tls.VerifyClientCertIfGiven, |
||||
}, |
||||
{ |
||||
desc: "RequireAndVerifyClientCert", |
||||
authType: "RequireAndVerifyClientCert", |
||||
want: tls.RequireAndVerifyClientCert, |
||||
}, |
||||
{ |
||||
desc: "unknown auth type", |
||||
authType: "InvalidAuthType", |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.desc, func(t *testing.T) { |
||||
got, err := parseClientAuthType(tc.authType) |
||||
if tc.wantErr { |
||||
require.Error(t, err) |
||||
return |
||||
} |
||||
require.NoError(t, err) |
||||
require.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestParseCipherSuites(t *testing.T) { |
||||
tt := []struct { |
||||
desc string |
||||
ciphers []string |
||||
want []uint16 |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
desc: "empty returns nil", |
||||
ciphers: []string{}, |
||||
want: nil, |
||||
}, |
||||
{ |
||||
desc: "single cipher", |
||||
ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"}, |
||||
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, |
||||
}, |
||||
{ |
||||
desc: "multiple ciphers", |
||||
ciphers: []string{"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256"}, |
||||
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_AES_128_GCM_SHA256}, |
||||
}, |
||||
{ |
||||
desc: "cipher with whitespace", |
||||
ciphers: []string{" TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 "}, |
||||
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, |
||||
}, |
||||
{ |
||||
desc: "unknown cipher", |
||||
ciphers: []string{"UNKNOWN_CIPHER"}, |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.desc, func(t *testing.T) { |
||||
got, err := parseCipherSuites(tc.ciphers) |
||||
if tc.wantErr { |
||||
require.Error(t, err) |
||||
return |
||||
} |
||||
require.NoError(t, err) |
||||
require.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestParseCurvePreferences(t *testing.T) { |
||||
tt := []struct { |
||||
desc string |
||||
curves []string |
||||
want []tls.CurveID |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
desc: "empty returns nil", |
||||
curves: []string{}, |
||||
want: nil, |
||||
}, |
||||
{ |
||||
desc: "X25519", |
||||
curves: []string{"X25519"}, |
||||
want: []tls.CurveID{tls.X25519}, |
||||
}, |
||||
{ |
||||
desc: "multiple curves", |
||||
curves: []string{"X25519", "CurveP256", "CurveP384"}, |
||||
want: []tls.CurveID{tls.X25519, tls.CurveP256, tls.CurveP384}, |
||||
}, |
||||
{ |
||||
desc: "curve with whitespace", |
||||
curves: []string{" CurveP256 "}, |
||||
want: []tls.CurveID{tls.CurveP256}, |
||||
}, |
||||
{ |
||||
desc: "unknown curve", |
||||
curves: []string{"UNKNOWN_CURVE"}, |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tt { |
||||
t.Run(tc.desc, func(t *testing.T) { |
||||
got, err := parseCurvePreferences(tc.curves) |
||||
if tc.wantErr { |
||||
require.Error(t, err) |
||||
return |
||||
} |
||||
require.NoError(t, err) |
||||
require.Equal(t, tc.want, got) |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,26 @@ |
||||
# Build the passthrough-gateway binary |
||||
FROM golang:1.26.3 as builder |
||||
|
||||
WORKDIR /workspace |
||||
# Copy the Go Modules manifests |
||||
COPY api/ api/ |
||||
COPY go.mod go.sum ./ |
||||
# cache deps before building and copying source so that we don't need to re-download as much |
||||
# and so that source changes don't invalidate our downloaded layer |
||||
RUN go mod download |
||||
|
||||
# Copy the go source |
||||
COPY cmd/passthrough-gateway/main.go cmd/passthrough-gateway/main.go |
||||
COPY internal/ internal/ |
||||
|
||||
# Build |
||||
RUN CGO_ENABLED=0 GOOS=linux GO111MODULE=on go build -mod=readonly -o passthrough-gateway cmd/passthrough-gateway/main.go |
||||
|
||||
# Use distroless as minimal base image to package the passthrough-gateway binary |
||||
# Refer to https://github.com/GoogleContainerTools/distroless for more details |
||||
FROM gcr.io/distroless/static:nonroot |
||||
WORKDIR / |
||||
COPY --from=builder /workspace/passthrough-gateway . |
||||
USER 65532:65532 |
||||
|
||||
ENTRYPOINT ["/passthrough-gateway"] |
||||
Loading…
Reference in new issue