|
|
|
@ -58,21 +58,30 @@ type Service struct { |
|
|
|
|
nsMapper request.NamespaceMapper |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (string, error) { |
|
|
|
|
func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (string, *authnlib.Claims[authnlib.IDTokenClaims], error) { |
|
|
|
|
defer func(t time.Time) { |
|
|
|
|
s.metrics.tokenSigningDurationHistogram.Observe(time.Since(t).Seconds()) |
|
|
|
|
}(time.Now()) |
|
|
|
|
|
|
|
|
|
cacheKey := prefixCacheKey(id.GetCacheKey()) |
|
|
|
|
|
|
|
|
|
result, err, _ := s.si.Do(cacheKey, func() (interface{}, error) { |
|
|
|
|
type resultType struct { |
|
|
|
|
token string |
|
|
|
|
idClaims *auth.IDClaims |
|
|
|
|
} |
|
|
|
|
result, err, _ := s.si.Do(cacheKey, func() (any, error) { |
|
|
|
|
namespace, identifier := id.GetTypedID() |
|
|
|
|
|
|
|
|
|
cachedToken, err := s.cache.Get(ctx, cacheKey) |
|
|
|
|
if err == nil { |
|
|
|
|
s.metrics.tokenSigningFromCacheCounter.Inc() |
|
|
|
|
s.logger.FromContext(ctx).Debug("Cached token found", "namespace", namespace, "id", identifier) |
|
|
|
|
return string(cachedToken), nil |
|
|
|
|
|
|
|
|
|
tokenClaims, err := s.extractTokenClaims(string(cachedToken)) |
|
|
|
|
if err != nil { |
|
|
|
|
return resultType{}, err |
|
|
|
|
} |
|
|
|
|
return resultType{token: string(cachedToken), idClaims: tokenClaims}, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
s.metrics.tokenSigningCounter.Inc() |
|
|
|
@ -104,21 +113,12 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri |
|
|
|
|
token, err := s.signer.SignIDToken(ctx, claims) |
|
|
|
|
if err != nil { |
|
|
|
|
s.metrics.failedTokenSigningCounter.Inc() |
|
|
|
|
return "", err |
|
|
|
|
return resultType{}, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
parsed, err := jwt.ParseSigned(token) |
|
|
|
|
extracted, err := s.extractTokenClaims(token) |
|
|
|
|
if err != nil { |
|
|
|
|
s.metrics.failedTokenSigningCounter.Inc() |
|
|
|
|
return "", err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
extracted := auth.IDClaims{} |
|
|
|
|
// We don't need to verify the signature here, we are only interested in checking
|
|
|
|
|
// when the token expires.
|
|
|
|
|
if err := parsed.UnsafeClaimsWithoutVerification(&extracted); err != nil { |
|
|
|
|
s.metrics.failedTokenSigningCounter.Inc() |
|
|
|
|
return "", err |
|
|
|
|
return resultType{}, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
expires := time.Until(extracted.Expiry.Time()) |
|
|
|
@ -126,14 +126,14 @@ func (s *Service) SignIdentity(ctx context.Context, id identity.Requester) (stri |
|
|
|
|
s.logger.FromContext(ctx).Error("Failed to add id token to cache", "error", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return token, nil |
|
|
|
|
return resultType{token: token, idClaims: claims}, nil |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
return "", err |
|
|
|
|
return "", nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return result.(string), nil |
|
|
|
|
return result.(resultType).token, result.(resultType).idClaims, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Service) RemoveIDToken(ctx context.Context, id identity.Requester) error { |
|
|
|
@ -142,7 +142,7 @@ func (s *Service) RemoveIDToken(ctx context.Context, id identity.Requester) erro |
|
|
|
|
|
|
|
|
|
func (s *Service) hook(ctx context.Context, identity *authn.Identity, _ *authn.Request) error { |
|
|
|
|
// FIXME(kalleep): we should probably lazy load this
|
|
|
|
|
token, err := s.SignIdentity(ctx, identity) |
|
|
|
|
token, claims, err := s.SignIdentity(ctx, identity) |
|
|
|
|
if err != nil { |
|
|
|
|
if shouldLogErr(err) { |
|
|
|
|
namespace, id := identity.GetTypedID() |
|
|
|
@ -153,9 +153,28 @@ func (s *Service) hook(ctx context.Context, identity *authn.Identity, _ *authn.R |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
identity.IDToken = token |
|
|
|
|
identity.IDTokenClaims = claims |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Service) extractTokenClaims(token string) (*authnlib.Claims[authnlib.IDTokenClaims], error) { |
|
|
|
|
parsed, err := jwt.ParseSigned(token) |
|
|
|
|
if err != nil { |
|
|
|
|
s.metrics.failedTokenSigningCounter.Inc() |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
extracted := authnlib.Claims[authnlib.IDTokenClaims]{} |
|
|
|
|
// We don't need to verify the signature here, we are only interested in checking
|
|
|
|
|
// when the token expires.
|
|
|
|
|
if err := parsed.UnsafeClaimsWithoutVerification(&extracted); err != nil { |
|
|
|
|
s.metrics.failedTokenSigningCounter.Inc() |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return &extracted, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getAudience(orgID int64) jwt.Audience { |
|
|
|
|
return jwt.Audience{fmt.Sprintf("org:%d", orgID)} |
|
|
|
|
} |
|
|
|
|