|
|
|
@ -17,7 +17,8 @@ import ( |
|
|
|
|
|
|
|
|
|
type SocialAzureAD struct { |
|
|
|
|
*SocialBase |
|
|
|
|
allowedGroups []string |
|
|
|
|
allowedGroups []string |
|
|
|
|
forceUseGraphAPI bool |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type azureClaims struct { |
|
|
|
@ -76,7 +77,7 @@ func (s *SocialAzureAD) UserInfo(client *http.Client, token *oauth2.Token) (*Bas |
|
|
|
|
|
|
|
|
|
logger.Debug("AzureAD OAuth: extracted role", "email", email, "role", role) |
|
|
|
|
|
|
|
|
|
groups, err := extractGroups(client, claims, token) |
|
|
|
|
groups, err := s.extractGroups(client, claims, token) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("failed to extract groups: %w", err) |
|
|
|
|
} |
|
|
|
@ -166,39 +167,26 @@ type getAzureGroupResponse struct { |
|
|
|
|
Value []string `json:"value"` |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) { |
|
|
|
|
if len(claims.Groups) > 0 { |
|
|
|
|
return claims.Groups, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if claims.ClaimNames.Groups == "" { |
|
|
|
|
return []string{}, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If user groups exceeds 200 no groups will be found in claims.
|
|
|
|
|
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
|
|
|
|
endpoint := claims.ClaimSources[claims.ClaimNames.Groups].Endpoint |
|
|
|
|
|
|
|
|
|
// If the endpoints provided in _claim_source is pointing to the deprecated "graph.windows.net" api
|
|
|
|
|
// replace with handcrafted url to graph.microsoft.com
|
|
|
|
|
// See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
|
|
|
|
|
if strings.Contains(endpoint, "graph.windows.net") { |
|
|
|
|
tenantID := claims.TenantID |
|
|
|
|
// If tenantID wasn't found in the id_token, parse access token
|
|
|
|
|
if tenantID == "" { |
|
|
|
|
parsedToken, err := jwt.ParseSigned(token.AccessToken) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, fmt.Errorf("error parsing access token: %w", err) |
|
|
|
|
} |
|
|
|
|
// extractGroups retrieves groups from the claims.
|
|
|
|
|
// Note: If user groups exceeds 200 no groups will be found in claims and URL to target the Graph API will be
|
|
|
|
|
// given instead.
|
|
|
|
|
// See https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#groups-overage-claim
|
|
|
|
|
func (s *SocialAzureAD) extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) ([]string, error) { |
|
|
|
|
if !s.forceUseGraphAPI { |
|
|
|
|
logger.Debug("checking the claim for groups") |
|
|
|
|
if len(claims.Groups) > 0 { |
|
|
|
|
return claims.Groups, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var accessClaims azureAccessClaims |
|
|
|
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil { |
|
|
|
|
return nil, fmt.Errorf("error getting claims from access token: %w", err) |
|
|
|
|
} |
|
|
|
|
tenantID = accessClaims.TenantID |
|
|
|
|
if claims.ClaimNames.Groups == "" { |
|
|
|
|
return []string{}, nil |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID) |
|
|
|
|
// Fallback to the Graph API
|
|
|
|
|
endpoint, errBuildGraphURI := groupsGraphAPIURL(claims, token) |
|
|
|
|
if errBuildGraphURI != nil { |
|
|
|
|
return nil, errBuildGraphURI |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
data, err := json.Marshal(&getAzureGroupRequest{SecurityEnabledOnly: false}) |
|
|
|
@ -234,3 +222,38 @@ func extractGroups(client *http.Client, claims azureClaims, token *oauth2.Token) |
|
|
|
|
|
|
|
|
|
return body.Value, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// groupsGraphAPIURL retrieves the Microsoft Graph API URL to fetch user groups from the _claim_sources if present
|
|
|
|
|
// otherwise it generates an handcrafted URL.
|
|
|
|
|
func groupsGraphAPIURL(claims azureClaims, token *oauth2.Token) (string, error) { |
|
|
|
|
var endpoint string |
|
|
|
|
// First check if an endpoint was specified in the claims
|
|
|
|
|
if claims.ClaimNames.Groups != "" { |
|
|
|
|
endpoint = claims.ClaimSources[claims.ClaimNames.Groups].Endpoint |
|
|
|
|
logger.Debug(fmt.Sprintf("endpoint to fetch groups specified in the claims: %s", endpoint)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// If no endpoint was specified or if the endpoints provided in _claim_source is pointing to the deprecated
|
|
|
|
|
// "graph.windows.net" api, use an handcrafted url to graph.microsoft.com
|
|
|
|
|
// See https://docs.microsoft.com/en-us/graph/migrate-azure-ad-graph-overview
|
|
|
|
|
if endpoint == "" || strings.Contains(endpoint, "graph.windows.net") { |
|
|
|
|
tenantID := claims.TenantID |
|
|
|
|
// If tenantID wasn't found in the id_token, parse access token
|
|
|
|
|
if tenantID == "" { |
|
|
|
|
parsedToken, err := jwt.ParseSigned(token.AccessToken) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", fmt.Errorf("error parsing access token: %w", err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var accessClaims azureAccessClaims |
|
|
|
|
if err := parsedToken.UnsafeClaimsWithoutVerification(&accessClaims); err != nil { |
|
|
|
|
return "", fmt.Errorf("error getting claims from access token: %w", err) |
|
|
|
|
} |
|
|
|
|
tenantID = accessClaims.TenantID |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
endpoint = fmt.Sprintf("https://graph.microsoft.com/v1.0/%s/users/%s/getMemberObjects", tenantID, claims.ID) |
|
|
|
|
logger.Debug(fmt.Sprintf("handcrafted endpoint to fetch groups: %s", endpoint)) |
|
|
|
|
} |
|
|
|
|
return endpoint, nil |
|
|
|
|
} |
|
|
|
|