From 80dfa788c6c9c9be9a672eb17d80183894a3f6ef Mon Sep 17 00:00:00 2001 From: Gabriel MABILLE Date: Tue, 4 Oct 2022 13:48:15 +0200 Subject: [PATCH] Azure OAuth: Use TID from id_token by default (#56264) Co-authored-by: Kalle Persson Co-authored-by: Kalle Persson --- pkg/login/social/azuread_oauth.go | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/pkg/login/social/azuread_oauth.go b/pkg/login/social/azuread_oauth.go index c5c4fb4ca94..a1aaeb7b7a9 100644 --- a/pkg/login/social/azuread_oauth.go +++ b/pkg/login/social/azuread_oauth.go @@ -29,6 +29,7 @@ type azureClaims struct { ID string `json:"oid"` ClaimNames claimNames `json:"_claim_names,omitempty"` ClaimSources map[string]claimSource `json:"_claim_sources,omitempty"` + TenantID string `json:"tid,omitempty"` } type claimNames struct { @@ -177,20 +178,27 @@ func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) // If user groups exceeds 200 no groups will be found in claims. // See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim endpoint := claims.ClaimSources[claims.ClaimNames.Groups].Endpoint + + // If the endpoints provided in _claim_source is pointing to the deprecated "graph.windows.net" api + // replace with handcrafted url to graph.microsoft.com + // See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview if strings.Contains(endpoint, "graph.windows.net") { - // If the endpoints provided in _claim_source is pointed to the deprecated "graph.windows.net" api - // replace with handcrafted url to graph.microsoft.com - // See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview - parsedToken, err := jwt.ParseSigned(token.AccessToken) - if err != nil { - return nil, fmt.Errorf("error parsing id token: %w", err) - } + tenantID := claims.TenantID + // If tenantID wasn't found in the id_token, parse access token + if tenantID == "" { + parsedToken, err := jwt.ParseSigned(token.AccessToken) + if err != nil { + return nil, fmt.Errorf("error parsing access token: %w", err) + } - var accessClaims azureAccessClaims - if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil { - return nil, fmt.Errorf("error getting claims from access token: %w", err) + var accessClaims azureAccessClaims + if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil { + return nil, fmt.Errorf("error getting claims from access token: %w", err) + } + tenantID = accessClaims.TenantID } - endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", accessClaims.TenantID, claims.ID) + + endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID) } data, err := json.Marshal(&getAzureGroupRequest{SecurityEnabledOnly: false})