The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/auth/jwt/key_sets.go

268 lines
6.6 KiB

package jwt
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
jose "github.com/go-jose/go-jose/v3"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/setting"
)
// Sentinel errors reported while configuring the key set used for JWT
// verification. Compare against them with errors.Is.
var (
	// ErrFailedToParsePemFile is returned when the configured key file does
	// not contain a PEM block.
	ErrFailedToParsePemFile = errors.New("failed to parse pem-encoded file")

	// ErrKeySetIsNotConfigured is returned when none of key_file,
	// jwk_set_file or jwk_set_url is set.
	ErrKeySetIsNotConfigured = errors.New("key set for jwt verification is not configured")

	// ErrKeySetConfigurationAmbiguous is returned when more than one of
	// key_file, jwk_set_file or jwk_set_url is set.
	ErrKeySetConfigurationAmbiguous = errors.New("key set configuration is ambiguous: you should set either key_file, jwk_set_file or jwk_set_url")

	// ErrJWTSetURLMustHaveHTTPSScheme is returned when jwk_set_url does not
	// use the https scheme outside of the development environment.
	// The message names the option as it is spelled in the configuration
	// (jwk_set_url), matching ErrKeySetConfigurationAmbiguous.
	ErrJWTSetURLMustHaveHTTPSScheme = errors.New("jwk_set_url must have https scheme")
)
// keySet is a source of JSON Web Keys used for JWT verification.
// It is implemented by keySetJWKS (keys held in memory) and keySetHTTP
// (keys fetched from a URL).
type keySet interface {
// Key looks up the keys for the key ID kid. An error is returned when the
// underlying key source could not be consulted.
Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error)
}
// keySetJWKS is a keySet backed by an in-memory jose.JSONWebKeySet.
// The set is embedded, so it also serves as the JSON decoding target for
// JWKS documents read from a file, the cache or an HTTP response.
type keySetJWKS struct {
jose.JSONWebKeySet
}
// keySetHTTP is a keySet that downloads a JWKS document from a URL and,
// when cacheExpiration is positive, keeps the raw document in a remote cache.
type keySetHTTP struct {
url string // endpoint the JWKS document is requested from
log log.Logger // logger for fetch and cache diagnostics
client *http.Client // HTTP client used to perform the request
cache *remotecache.RemoteCache // store for the raw JWKS document
cacheKey string // key under which the document is cached
cacheExpiration time.Duration // configured cache TTL; caching is skipped unless it is > 0
}
// checkKeySetConfiguration verifies that exactly one key source is configured.
//
// It returns ErrKeySetIsNotConfigured when none of KeyFile, JWKSetFile and
// JWKSetURL is set, ErrKeySetConfigurationAmbiguous when more than one is set,
// and nil otherwise.
func (s *AuthService) checkKeySetConfiguration() error {
	configured := 0
	for _, source := range []string{
		s.Cfg.JWTAuth.KeyFile,
		s.Cfg.JWTAuth.JWKSetFile,
		s.Cfg.JWTAuth.JWKSetURL,
	} {
		if source != "" {
			configured++
		}
	}

	switch {
	case configured == 0:
		return ErrKeySetIsNotConfigured
	case configured > 1:
		return ErrKeySetConfigurationAmbiguous
	default:
		return nil
	}
}
// initKeySet builds s.keySet from the single key source that is configured.
//
// The configuration is validated first (see checkKeySetConfiguration), then
// the source is loaded in this order of precedence: JWTAuth.KeyFile (a PEM
// encoded key), JWTAuth.JWKSetFile (a JWKS document on disk) or
// JWTAuth.JWKSetURL (a JWKS document fetched lazily over HTTP).
// On error s.keySet is left unchanged.
func (s *AuthService) initKeySet() error {
	if err := s.checkKeySetConfiguration(); err != nil {
		return err
	}

	switch {
	case s.Cfg.JWTAuth.KeyFile != "":
		ks, err := s.keySetFromPEMFile(s.Cfg.JWTAuth.KeyFile)
		if err != nil {
			return err
		}
		s.keySet = ks
	case s.Cfg.JWTAuth.JWKSetFile != "":
		ks, err := s.keySetFromJWKSFile(s.Cfg.JWTAuth.JWKSetFile)
		if err != nil {
			return err
		}
		s.keySet = ks
	case s.Cfg.JWTAuth.JWKSetURL != "":
		ks, err := s.keySetFromURL(s.Cfg.JWTAuth.JWKSetURL)
		if err != nil {
			return err
		}
		s.keySet = ks
	}
	return nil
}

// keySetFromPEMFile reads the PEM file at path and wraps the key from its
// first PEM block in a single-entry key set tagged with JWTAuth.KeyID.
// It returns ErrFailedToParsePemFile when the file holds no PEM block.
func (s *AuthService) keySetFromPEMFile(path string) (*keySetJWKS, error) {
	// nolint:gosec
	// We can ignore the gosec G304 warning on this one because `path` comes from grafana configuration file
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	// Only the first PEM block is considered; any trailing data is ignored.
	block, _ := pem.Decode(data)
	if block == nil {
		return nil, ErrFailedToParsePemFile
	}

	key, err := parsePEMBlockKey(block)
	if err != nil {
		return nil, err
	}

	return &keySetJWKS{
		jose.JSONWebKeySet{
			Keys: []jose.JSONWebKey{{Key: key, KeyID: s.Cfg.JWTAuth.KeyID}},
		},
	}, nil
}

// parsePEMBlockKey parses the key held in block according to the block type.
// Unsupported block types are reported as an error.
func parsePEMBlockKey(block *pem.Block) (any, error) {
	var (
		key any
		err error
	)
	switch block.Type {
	case "PUBLIC KEY":
		key, err = x509.ParsePKIXPublicKey(block.Bytes)
	case "PRIVATE KEY":
		key, err = x509.ParsePKCS8PrivateKey(block.Bytes)
	case "RSA PUBLIC KEY":
		key, err = x509.ParsePKCS1PublicKey(block.Bytes)
	case "RSA PRIVATE KEY":
		key, err = x509.ParsePKCS1PrivateKey(block.Bytes)
	case "EC PRIVATE KEY":
		key, err = x509.ParseECPrivateKey(block.Bytes)
	default:
		return nil, fmt.Errorf("unknown pem block type %q", block.Type)
	}
	if err != nil {
		// Return an untyped nil so callers never see a typed-nil key.
		return nil, err
	}
	return key, nil
}

// keySetFromJWKSFile decodes the JWKS document stored in the file at path.
func (s *AuthService) keySetFromJWKSFile(path string) (*keySetJWKS, error) {
	// nolint:gosec
	// We can ignore the gosec G304 warning on this one because `path` comes from grafana configuration file
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err := file.Close(); err != nil {
			s.log.Warn("Failed to close file", "path", path, "err", err)
		}
	}()

	var jwks jose.JSONWebKeySet
	if err := json.NewDecoder(file).Decode(&jwks); err != nil {
		return nil, err
	}
	return &keySetJWKS{jwks}, nil
}

// keySetFromURL returns a key set that fetches its JWKS document from rawURL.
// No request is made here. Outside of the development environment the URL
// must use the https scheme, otherwise ErrJWTSetURLMustHaveHTTPSScheme is
// returned.
func (s *AuthService) keySetFromURL(rawURL string) (*keySetHTTP, error) {
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}
	if parsed.Scheme != "https" && s.Cfg.Env != setting.Dev {
		return nil, ErrJWTSetURLMustHaveHTTPSScheme
	}

	return &keySetHTTP{
		url: rawURL,
		log: s.log,
		client: &http.Client{
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					Renegotiation: tls.RenegotiateFreelyAsClient,
				},
				Proxy: http.ProxyFromEnvironment,
				DialContext: (&net.Dialer{
					Timeout:   time.Second * 30,
					KeepAlive: 15 * time.Second,
				}).DialContext,
				TLSHandshakeTimeout:   10 * time.Second,
				ExpectContinueTimeout: 1 * time.Second,
				MaxIdleConns:          100,
				IdleConnTimeout:       30 * time.Second,
			},
			Timeout: time.Second * 30,
		},
		cacheKey:        fmt.Sprintf("auth-jwt:jwk-%s", rawURL),
		cacheExpiration: s.Cfg.JWTAuth.CacheTTL,
		cache:           s.RemoteCache,
	}, nil
}
// Key returns the keys of the embedded set that match keyID.
// The lookup is purely in-memory: ctx is unused and the error is always nil.
func (ks *keySetJWKS) Key(ctx context.Context, keyID string) ([]jose.JSONWebKey, error) {
	matches := ks.JSONWebKeySet.Key(keyID)
	return matches, nil
}
// getJWKS returns the JWKS document for ks.url.
//
// When caching is enabled (cacheExpiration > 0) the cached document is used if
// it can be read and decoded. Otherwise the document is fetched over HTTP,
// decoded, and its raw bytes are stored in the cache with a TTL derived from
// the response's Cache-Control header (see getCacheExpiration).
//
// An error is returned when the request cannot be built or sent, when the
// endpoint answers with a non-2xx status, or when the body is not a valid
// JWKS document. On error the returned key set is the zero value.
func (ks *keySetHTTP) getJWKS(ctx context.Context) (keySetJWKS, error) {
	if ks.cacheExpiration > 0 {
		if val, err := ks.cache.Get(ctx, ks.cacheKey); err == nil {
			// Decode into a dedicated value so that a corrupt cache entry
			// cannot leak partially decoded keys into the fetched result.
			var cached keySetJWKS
			if err := json.Unmarshal(val, &cached); err != nil {
				ks.log.Warn("Failed to unmarshal key set from cache", "err", err)
			} else {
				return cached, nil
			}
		}
	}

	ks.log.Debug("Getting key set from endpoint", "url", ks.url)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, ks.url, nil)
	if err != nil {
		return keySetJWKS{}, err
	}

	resp, err := ks.client.Do(req)
	if err != nil {
		return keySetJWKS{}, err
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			ks.log.Warn("Failed to close response body", "err", err)
		}
	}()

	// An error response may still carry a JSON body. Without this check it
	// would be decoded as a (typically empty) key set and then cached.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return keySetJWKS{}, fmt.Errorf("fetching key set from %q: unexpected status %s", ks.url, resp.Status)
	}

	// Tee the body so the bytes consumed by the decoder can be cached as-is.
	var (
		jwks    keySetJWKS
		jsonBuf bytes.Buffer
	)
	if err := json.NewDecoder(io.TeeReader(resp.Body, &jsonBuf)).Decode(&jwks); err != nil {
		return keySetJWKS{}, err
	}

	if ks.cacheExpiration > 0 {
		cacheControl := resp.Header.Get("cache-control")
		cacheExpiration := ks.getCacheExpiration(cacheControl)
		ks.log.Debug("Setting key set in cache", "url", ks.url,
			"cacheExpiration", cacheExpiration, "cacheControl", cacheControl)
		// The cache is an optimisation: the keys were fetched successfully, so
		// a failed write is logged instead of failing the lookup. This mirrors
		// the read path, where a cache miss or error falls back to fetching.
		if err := ks.cache.Set(ctx, ks.cacheKey, jsonBuf.Bytes(), cacheExpiration); err != nil {
			ks.log.Warn("Failed to set key set in cache", "url", ks.url, "err", err)
		}
	}
	return jwks, nil
}
// getCacheExpiration derives the cache TTL for a fetched key set from the
// response's Cache-Control header value.
//
// The configured ks.cacheExpiration is the default and the upper bound: a
// max-age directive only takes effect when it is shorter than the configured
// value (or when no value is configured). A max-age that is malformed or
// negative makes the function fall back to the configured value. Directive
// names are matched case-insensitively.
func (ks *keySetHTTP) getCacheExpiration(cacheControl string) time.Duration {
	// Largest number of whole seconds that fits in a time.Duration; larger
	// max-age values are clamped so the multiplication below cannot overflow.
	const maxAgeLimit = (1<<63 - 1) / int64(time.Second)

	cacheDuration := ks.cacheExpiration

	for _, part := range strings.Split(cacheControl, ",") {
		name, value, found := strings.Cut(strings.TrimSpace(part), "=")
		if !found || !strings.EqualFold(name, "max-age") {
			continue
		}

		maxAge, err := strconv.ParseInt(value, 10, 64)
		if err != nil || maxAge < 0 {
			return cacheDuration
		}
		if maxAge > maxAgeLimit {
			maxAge = maxAgeLimit
		}

		// If the cache duration is 0 or the max-age is less than the cache duration, use the max-age
		ttl := time.Duration(maxAge) * time.Second
		if cacheDuration == 0 || ttl < cacheDuration {
			return ttl
		}
	}

	return cacheDuration
}
// Key obtains the current JWKS document (from the cache or the endpoint, see
// getJWKS) and returns the keys in it that match kid. Any error from obtaining
// the document is returned with a nil key slice.
//
// Unlike the other keySetHTTP methods, Key has a value receiver.
func (ks keySetHTTP) Key(ctx context.Context, kid string) ([]jose.JSONWebKey, error) {
	set, fetchErr := ks.getJWKS(ctx)
	if fetchErr == nil {
		return set.Key(ctx, kid)
	}
	return nil, fetchErr
}