mirror of https://github.com/grafana/grafana
Contexthandler: Remove code that is no longer used (#73101)
* Contexthandler: remove dead code * Contexthandler: Add tests * Update pkg/tests/api/alerting/api_alertmanager_test.go Co-authored-by: Jo <joao.guerreiro@grafana.com> --------- Co-authored-by: Jo <joao.guerreiro@grafana.com>pull/73109/head
parent
5d8e6aa162
commit
e53e22ef2a
@ -1,222 +0,0 @@ |
||||
package contexthandler |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"net/http" |
||||
"strings" |
||||
|
||||
"github.com/jmespath/go-jmespath" |
||||
|
||||
"github.com/grafana/grafana/pkg/login" |
||||
"github.com/grafana/grafana/pkg/models/roletype" |
||||
authJWT "github.com/grafana/grafana/pkg/services/auth/jwt" |
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" |
||||
loginsvc "github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/org" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
) |
||||
|
||||
const ( |
||||
InvalidJWT = "Invalid JWT" |
||||
InvalidRole = "Invalid Role" |
||||
UserNotFound = "User not found" |
||||
authQueryParamName = "auth_token" |
||||
) |
||||
|
||||
func (h *ContextHandler) initContextWithJWT(ctx *contextmodel.ReqContext, orgId int64) bool { |
||||
if !h.Cfg.JWTAuthEnabled || h.Cfg.JWTAuthHeaderName == "" { |
||||
return false |
||||
} |
||||
|
||||
jwtToken := ctx.Req.Header.Get(h.Cfg.JWTAuthHeaderName) |
||||
if jwtToken == "" && h.Cfg.JWTAuthURLLogin { |
||||
jwtToken = ctx.Req.URL.Query().Get(authQueryParamName) |
||||
} |
||||
|
||||
if jwtToken == "" { |
||||
return false |
||||
} |
||||
|
||||
stripSensitiveParam(h.Cfg, ctx.Req) |
||||
|
||||
// Strip the 'Bearer' prefix if it exists.
|
||||
jwtToken = strings.TrimPrefix(jwtToken, "Bearer ") |
||||
|
||||
// If the "sub" claim is missing or empty then pass the control to the next handler
|
||||
if !authJWT.HasSubClaim(jwtToken) { |
||||
return false |
||||
} |
||||
|
||||
claims, err := h.JWTAuthService.Verify(ctx.Req.Context(), jwtToken) |
||||
if err != nil { |
||||
ctx.Logger.Debug("Failed to verify JWT", "error", err) |
||||
ctx.JsonApiErr(http.StatusUnauthorized, InvalidJWT, err) |
||||
return true |
||||
} |
||||
|
||||
query := user.GetSignedInUserQuery{OrgID: orgId} |
||||
|
||||
sub, _ := claims["sub"].(string) |
||||
|
||||
if sub == "" { |
||||
ctx.Logger.Warn("Got a JWT without the mandatory 'sub' claim", "error", err) |
||||
ctx.JsonApiErr(http.StatusUnauthorized, InvalidJWT, err) |
||||
return true |
||||
} |
||||
extUser := &loginsvc.ExternalUserInfo{ |
||||
AuthModule: "jwt", |
||||
AuthId: sub, |
||||
OrgRoles: map[int64]org.RoleType{}, |
||||
// we do not want to sync team memberships from JWT authentication see - https://github.com/grafana/grafana/issues/62175
|
||||
SkipTeamSync: true, |
||||
} |
||||
|
||||
if key := h.Cfg.JWTAuthUsernameClaim; key != "" { |
||||
query.Login, _ = claims[key].(string) |
||||
extUser.Login, _ = claims[key].(string) |
||||
} |
||||
if key := h.Cfg.JWTAuthEmailClaim; key != "" { |
||||
query.Email, _ = claims[key].(string) |
||||
extUser.Email, _ = claims[key].(string) |
||||
} |
||||
|
||||
if name, _ := claims["name"].(string); name != "" { |
||||
extUser.Name = name |
||||
} |
||||
|
||||
var role roletype.RoleType |
||||
var grafanaAdmin bool |
||||
if !h.Cfg.JWTAuthSkipOrgRoleSync { |
||||
role, grafanaAdmin = h.extractJWTRoleAndAdmin(claims) |
||||
if h.Cfg.JWTAuthRoleAttributeStrict && !role.IsValid() { |
||||
ctx.Logger.Debug("Extracted Role is invalid") |
||||
ctx.JsonApiErr(http.StatusForbidden, InvalidRole, nil) |
||||
return true |
||||
} |
||||
if role.IsValid() { |
||||
var orgID int64 |
||||
if h.Cfg.AutoAssignOrg && h.Cfg.AutoAssignOrgId > 0 { |
||||
orgID = int64(h.Cfg.AutoAssignOrgId) |
||||
ctx.Logger.Debug("The user has a role assignment and organization membership is auto-assigned", |
||||
"role", role, "orgId", orgID) |
||||
} else { |
||||
orgID = int64(1) |
||||
ctx.Logger.Debug("The user has a role assignment and organization membership is not auto-assigned", |
||||
"role", role, "orgId", orgID) |
||||
} |
||||
|
||||
extUser.OrgRoles[orgID] = role |
||||
if h.Cfg.JWTAuthAllowAssignGrafanaAdmin { |
||||
extUser.IsGrafanaAdmin = &grafanaAdmin |
||||
} |
||||
} |
||||
} |
||||
|
||||
if query.Login == "" && query.Email == "" { |
||||
ctx.Logger.Debug("Failed to get an authentication claim from JWT") |
||||
ctx.JsonApiErr(http.StatusUnauthorized, InvalidJWT, err) |
||||
return true |
||||
} |
||||
|
||||
if h.Cfg.JWTAuthAutoSignUp { |
||||
upsert := &loginsvc.UpsertUserCommand{ |
||||
ReqContext: ctx, |
||||
SignupAllowed: h.Cfg.JWTAuthAutoSignUp, |
||||
ExternalUser: extUser, |
||||
UserLookupParams: loginsvc.UserLookupParams{ |
||||
UserID: nil, |
||||
Login: &query.Login, |
||||
Email: &query.Email, |
||||
}, |
||||
} |
||||
if _, err := h.loginService.UpsertUser(ctx.Req.Context(), upsert); err != nil { |
||||
ctx.Logger.Error("Failed to upsert JWT user", "error", err) |
||||
return false |
||||
} |
||||
} |
||||
|
||||
queryResult, err := h.userService.GetSignedInUserWithCacheCtx(ctx.Req.Context(), &query) |
||||
if err != nil { |
||||
if errors.Is(err, user.ErrUserNotFound) { |
||||
ctx.Logger.Debug( |
||||
"Failed to find user using JWT claims", |
||||
"email_claim", query.Email, |
||||
"username_claim", query.Login, |
||||
) |
||||
err = login.ErrInvalidCredentials |
||||
ctx.JsonApiErr(http.StatusUnauthorized, UserNotFound, err) |
||||
} else { |
||||
ctx.Logger.Error("Failed to get signed in user", "error", err) |
||||
ctx.JsonApiErr(http.StatusUnauthorized, InvalidJWT, err) |
||||
} |
||||
return true |
||||
} |
||||
|
||||
ctx.SignedInUser = queryResult |
||||
ctx.IsSignedIn = true |
||||
|
||||
return true |
||||
} |
||||
|
||||
const roleGrafanaAdmin = "GrafanaAdmin" |
||||
|
||||
func (h *ContextHandler) extractJWTRoleAndAdmin(claims map[string]interface{}) (org.RoleType, bool) { |
||||
if h.Cfg.JWTAuthRoleAttributePath == "" { |
||||
return "", false |
||||
} |
||||
|
||||
role, err := searchClaimsForStringAttr(h.Cfg.JWTAuthRoleAttributePath, claims) |
||||
if err != nil || role == "" { |
||||
return "", false |
||||
} |
||||
|
||||
if role == roleGrafanaAdmin { |
||||
return org.RoleAdmin, true |
||||
} |
||||
return org.RoleType(role), false |
||||
} |
||||
|
||||
func searchClaimsForAttr(attributePath string, claims map[string]interface{}) (interface{}, error) { |
||||
if attributePath == "" { |
||||
return "", errors.New("no attribute path specified") |
||||
} |
||||
|
||||
if len(claims) == 0 { |
||||
return "", errors.New("empty claims provided") |
||||
} |
||||
|
||||
val, err := jmespath.Search(attributePath, claims) |
||||
if err != nil { |
||||
return "", fmt.Errorf("failed to search claims with provided path: %q: %w", attributePath, err) |
||||
} |
||||
|
||||
return val, nil |
||||
} |
||||
|
||||
func searchClaimsForStringAttr(attributePath string, claims map[string]interface{}) (string, error) { |
||||
val, err := searchClaimsForAttr(attributePath, claims) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
strVal, ok := val.(string) |
||||
if ok { |
||||
return strVal, nil |
||||
} |
||||
|
||||
return "", nil |
||||
} |
||||
|
||||
// remove sensitive query params
|
||||
// avoid JWT URL login passing auth_token in URL
|
||||
func stripSensitiveParam(cfg *setting.Cfg, httpRequest *http.Request) { |
||||
if cfg.JWTAuthURLLogin { |
||||
params := httpRequest.URL.Query() |
||||
if params.Has(authQueryParamName) { |
||||
params.Del(authQueryParamName) |
||||
httpRequest.URL.RawQuery = params.Encode() |
||||
} |
||||
} |
||||
} |
||||
@ -1,129 +0,0 @@ |
||||
package contexthandler |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/http" |
||||
"strconv" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/infra/remotecache" |
||||
"github.com/grafana/grafana/pkg/infra/tracing" |
||||
"github.com/grafana/grafana/pkg/infra/usagestats" |
||||
"github.com/grafana/grafana/pkg/services/anonymous/anontest" |
||||
"github.com/grafana/grafana/pkg/services/auth/authtest" |
||||
"github.com/grafana/grafana/pkg/services/auth/jwt" |
||||
"github.com/grafana/grafana/pkg/services/authn/authntest" |
||||
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy" |
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" |
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/services/ldap/service" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/login/loginservice" |
||||
"github.com/grafana/grafana/pkg/services/org/orgtest" |
||||
"github.com/grafana/grafana/pkg/services/rendering" |
||||
"github.com/grafana/grafana/pkg/services/secrets/fakes" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/services/user/usertest" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/web" |
||||
) |
||||
|
||||
const userID = int64(1) |
||||
const orgID = int64(4) |
||||
|
||||
// Test initContextWithAuthProxy with a cached user ID that is no longer valid.
|
||||
//
|
||||
// In this case, the cache entry should be ignored/cleared and another attempt should be done to sign the user
|
||||
// in without cache.
|
||||
func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) { |
||||
const name = "markelog" |
||||
|
||||
svc := getContextHandler(t) |
||||
|
||||
req, err := http.NewRequest("POST", "http://example.com", nil) |
||||
require.NoError(t, err) |
||||
ctx := &contextmodel.ReqContext{ |
||||
Context: &web.Context{Req: req}, |
||||
Logger: log.New("Test"), |
||||
} |
||||
req.Header.Set(svc.Cfg.AuthProxyHeaderName, name) |
||||
h, err := authproxy.HashCacheKey(name) |
||||
require.NoError(t, err) |
||||
key := fmt.Sprintf(authproxy.CachePrefix, h) |
||||
|
||||
t.Logf("Injecting stale user ID in cache with key %q", key) |
||||
userIdPayload := []byte(strconv.FormatInt(int64(33), 10)) |
||||
err = svc.RemoteCache.Set(context.Background(), key, userIdPayload, 0) |
||||
require.NoError(t, err) |
||||
|
||||
authEnabled := svc.initContextWithAuthProxy(ctx, orgID) |
||||
require.True(t, authEnabled) |
||||
|
||||
require.Equal(t, userID, ctx.SignedInUser.UserID) |
||||
require.True(t, ctx.IsSignedIn) |
||||
|
||||
cachedByteArray, err := svc.RemoteCache.Get(context.Background(), key) |
||||
require.NoError(t, err) |
||||
|
||||
cacheUserId, err := strconv.ParseInt(string(cachedByteArray), 10, 64) |
||||
|
||||
require.NoError(t, err) |
||||
require.Equal(t, userID, cacheUserId) |
||||
} |
||||
|
||||
type fakeRenderService struct { |
||||
rendering.Service |
||||
} |
||||
|
||||
func getContextHandler(t *testing.T) *ContextHandler { |
||||
t.Helper() |
||||
|
||||
sqlStore := db.InitTestDB(t) |
||||
|
||||
cfg := setting.NewCfg() |
||||
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{ |
||||
Name: "database", |
||||
} |
||||
cfg.AuthProxyHeaderName = "X-Killa" |
||||
cfg.AuthProxyEnabled = true |
||||
cfg.AuthProxyHeaderProperty = "username" |
||||
remoteCacheSvc, err := remotecache.ProvideService(cfg, sqlStore, &usagestats.UsageStatsMock{}, fakes.NewFakeSecretsService()) |
||||
require.NoError(t, err) |
||||
userAuthTokenSvc := authtest.NewFakeUserAuthTokenService() |
||||
renderSvc := &fakeRenderService{} |
||||
authJWTSvc := jwt.NewFakeJWTService() |
||||
tracer := tracing.InitializeTracerForTest() |
||||
|
||||
loginService := loginservice.LoginServiceMock{ExpectedUser: &user.User{ID: userID}} |
||||
userService := usertest.FakeUserService{ |
||||
GetSignedInUserFn: func(ctx context.Context, query *user.GetSignedInUserQuery) (*user.SignedInUser, error) { |
||||
if query.UserID != userID { |
||||
return &user.SignedInUser{}, user.ErrUserNotFound |
||||
} |
||||
return &user.SignedInUser{ |
||||
UserID: userID, |
||||
OrgID: orgID, |
||||
}, nil |
||||
}, |
||||
} |
||||
orgService := orgtest.NewOrgServiceFake() |
||||
|
||||
authProxy := authproxy.ProvideAuthProxy(cfg, remoteCacheSvc, loginService, &userService, nil, service.NewLDAPFakeService()) |
||||
authenticator := &fakeAuthenticator{} |
||||
|
||||
return ProvideService(cfg, userAuthTokenSvc, authJWTSvc, remoteCacheSvc, |
||||
renderSvc, sqlStore, tracer, authProxy, loginService, nil, authenticator, |
||||
&userService, orgService, nil, featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{}, &anontest.FakeAnonymousSessionService{}) |
||||
} |
||||
|
||||
type fakeAuthenticator struct{} |
||||
|
||||
func (fa *fakeAuthenticator) AuthenticateUser(c context.Context, query *login.LoginUserQuery) error { |
||||
return nil |
||||
} |
||||
@ -1,384 +0,0 @@ |
||||
package authproxy |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/hex" |
||||
"errors" |
||||
"fmt" |
||||
"hash/fnv" |
||||
"net" |
||||
"net/mail" |
||||
"path" |
||||
"reflect" |
||||
"strconv" |
||||
"strings" |
||||
"time" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/db" |
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/infra/remotecache" |
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" |
||||
"github.com/grafana/grafana/pkg/services/ldap" |
||||
"github.com/grafana/grafana/pkg/services/ldap/service" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/org" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/util" |
||||
) |
||||
|
||||
const ( |
||||
|
||||
// CachePrefix is a prefix for the cache key
|
||||
CachePrefix = "auth-proxy-sync-ttl:%s" |
||||
) |
||||
|
||||
// supportedHeaders states the supported headers configuration fields
|
||||
var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups", "Role"} |
||||
|
||||
// AuthProxy struct
|
||||
type AuthProxy struct { |
||||
cfg *setting.Cfg |
||||
remoteCache *remotecache.RemoteCache |
||||
loginService login.Service |
||||
sqlStore db.DB |
||||
userService user.Service |
||||
ldapService service.LDAP |
||||
|
||||
logger log.Logger |
||||
} |
||||
|
||||
func ProvideAuthProxy(cfg *setting.Cfg, remoteCache *remotecache.RemoteCache, |
||||
loginService login.Service, userService user.Service, |
||||
sqlStore db.DB, ldapService service.LDAP) *AuthProxy { |
||||
return &AuthProxy{ |
||||
cfg: cfg, |
||||
remoteCache: remoteCache, |
||||
loginService: loginService, |
||||
sqlStore: sqlStore, |
||||
userService: userService, |
||||
logger: log.New("auth.proxy"), |
||||
ldapService: ldapService, |
||||
} |
||||
} |
||||
|
||||
// Error auth proxy specific error
|
||||
type Error struct { |
||||
Message string |
||||
DetailsError error |
||||
} |
||||
|
||||
// newError returns an Error.
|
||||
func newError(message string, err error) Error { |
||||
return Error{ |
||||
Message: message, |
||||
DetailsError: err, |
||||
} |
||||
} |
||||
|
||||
// Error returns the error message.
|
||||
func (err Error) Error() string { |
||||
return err.Message |
||||
} |
||||
|
||||
// IsEnabled checks if the auth proxy is enabled.
|
||||
func (auth *AuthProxy) IsEnabled() bool { |
||||
// Bail if the setting is not enabled
|
||||
return auth.cfg.AuthProxyEnabled |
||||
} |
||||
|
||||
// HasHeader checks if we have specified header
|
||||
func (auth *AuthProxy) HasHeader(reqCtx *contextmodel.ReqContext) bool { |
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
return len(header) != 0 |
||||
} |
||||
|
||||
// IsAllowedIP returns whether provided IP is allowed.
|
||||
func (auth *AuthProxy) IsAllowedIP(ip string) error { |
||||
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 { |
||||
return nil |
||||
} |
||||
|
||||
proxies := strings.Split(auth.cfg.AuthProxyWhitelist, ",") |
||||
proxyObjs := make([]*net.IPNet, 0, len(proxies)) |
||||
for _, proxy := range proxies { |
||||
result, err := coerceProxyAddress(proxy) |
||||
if err != nil { |
||||
return newError("could not get the network", err) |
||||
} |
||||
|
||||
proxyObjs = append(proxyObjs, result) |
||||
} |
||||
|
||||
sourceIP, _, err := net.SplitHostPort(ip) |
||||
if err != nil { |
||||
return newError("could not parse address", err) |
||||
} |
||||
sourceObj := net.ParseIP(sourceIP) |
||||
|
||||
for _, proxyObj := range proxyObjs { |
||||
if proxyObj.Contains(sourceObj) { |
||||
return nil |
||||
} |
||||
} |
||||
|
||||
return newError("proxy authentication required", fmt.Errorf( |
||||
"request for user from %s is not from the authentication proxy", |
||||
sourceIP, |
||||
)) |
||||
} |
||||
|
||||
func HashCacheKey(key string) (string, error) { |
||||
hasher := fnv.New128a() |
||||
if _, err := hasher.Write([]byte(key)); err != nil { |
||||
return "", err |
||||
} |
||||
return hex.EncodeToString(hasher.Sum(nil)), nil |
||||
} |
||||
|
||||
// getKey forms a key for the cache based on the headers received as part of the authentication flow.
|
||||
// Our configuration supports multiple headers. The main header contains the email or username.
|
||||
// And the additional ones that allow us to specify extra attributes: Name, Email, Role, or Groups.
|
||||
func (auth *AuthProxy) getKey(reqCtx *contextmodel.ReqContext) (string, error) { |
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
key := strings.TrimSpace(header) // start the key with the main header
|
||||
|
||||
auth.headersIterator(reqCtx, func(_, header string) { |
||||
key = strings.Join([]string{key, header}, "-") // compose the key with any additional headers
|
||||
}) |
||||
|
||||
hashedKey, err := HashCacheKey(key) |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
return fmt.Sprintf(CachePrefix, hashedKey), nil |
||||
} |
||||
|
||||
// Login logs in user ID by whatever means possible.
|
||||
func (auth *AuthProxy) Login(reqCtx *contextmodel.ReqContext, ignoreCache bool) (int64, error) { |
||||
if !ignoreCache { |
||||
// Error here means absent cache - we don't need to handle that
|
||||
id, err := auth.getUserViaCache(reqCtx) |
||||
if err == nil && id != 0 { |
||||
return id, nil |
||||
} |
||||
} |
||||
|
||||
if auth.cfg.LDAPAuthEnabled { |
||||
id, err := auth.LoginViaLDAP(reqCtx) |
||||
if err != nil { |
||||
if errors.Is(err, ldap.ErrInvalidCredentials) { |
||||
return 0, newError("proxy authentication required", ldap.ErrInvalidCredentials) |
||||
} |
||||
return 0, newError("failed to get the user", err) |
||||
} |
||||
|
||||
return id, nil |
||||
} |
||||
|
||||
id, err := auth.loginViaHeader(reqCtx) |
||||
if err != nil { |
||||
return 0, newError("failed to log in as user, specified in auth proxy header", err) |
||||
} |
||||
|
||||
return id, nil |
||||
} |
||||
|
||||
// getUserViaCache gets user ID from cache.
|
||||
func (auth *AuthProxy) getUserViaCache(reqCtx *contextmodel.ReqContext) (int64, error) { |
||||
cacheKey, err := auth.getKey(reqCtx) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
auth.logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey) |
||||
cachedValue, err := auth.remoteCache.Get(reqCtx.Req.Context(), cacheKey) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
userId, err := strconv.ParseInt(string(cachedValue), 10, 64) |
||||
if err != nil { |
||||
auth.logger.Debug("Failed getting user ID via auth cache", "error", err) |
||||
return 0, err |
||||
} |
||||
|
||||
auth.logger.Debug("Successfully got user ID via auth cache", "id", cachedValue) |
||||
return userId, nil |
||||
} |
||||
|
||||
// RemoveUserFromCache removes user from cache.
|
||||
func (auth *AuthProxy) RemoveUserFromCache(reqCtx *contextmodel.ReqContext) error { |
||||
cacheKey, err := auth.getKey(reqCtx) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
auth.logger.Debug("Removing user from auth cache", "cacheKey", cacheKey) |
||||
if err := auth.remoteCache.Delete(reqCtx.Req.Context(), cacheKey); err != nil { |
||||
return err |
||||
} |
||||
|
||||
auth.logger.Debug("Successfully removed user from auth cache", "cacheKey", cacheKey) |
||||
return nil |
||||
} |
||||
|
||||
// LoginViaLDAP logs in user via LDAP request
|
||||
func (auth *AuthProxy) LoginViaLDAP(reqCtx *contextmodel.ReqContext) (int64, error) { |
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
|
||||
extUser, err := auth.ldapService.User(header) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
// Have to sync grafana and LDAP user during log in
|
||||
upsert := &login.UpsertUserCommand{ |
||||
ReqContext: reqCtx, |
||||
SignupAllowed: auth.cfg.LDAPAllowSignup, |
||||
ExternalUser: extUser, |
||||
UserLookupParams: login.UserLookupParams{ |
||||
Login: &extUser.Login, |
||||
Email: &extUser.Email, |
||||
UserID: nil, |
||||
}, |
||||
} |
||||
u, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
return u.ID, nil |
||||
} |
||||
|
||||
// loginViaHeader logs in user from the header only
|
||||
func (auth *AuthProxy) loginViaHeader(reqCtx *contextmodel.ReqContext) (int64, error) { |
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
extUser := &login.ExternalUserInfo{ |
||||
AuthModule: login.AuthProxyAuthModule, |
||||
AuthId: header, |
||||
} |
||||
|
||||
switch auth.cfg.AuthProxyHeaderProperty { |
||||
case "username": |
||||
extUser.Login = header |
||||
|
||||
emailAddr, emailErr := mail.ParseAddress(header) // only set Email if it can be parsed as an email address
|
||||
if emailErr == nil { |
||||
extUser.Email = emailAddr.Address |
||||
} |
||||
case "email": |
||||
extUser.Email = header |
||||
extUser.Login = header |
||||
default: |
||||
return 0, fmt.Errorf("auth proxy header property invalid") |
||||
} |
||||
|
||||
auth.headersIterator(reqCtx, func(field string, header string) { |
||||
switch field { |
||||
case "Groups": |
||||
extUser.Groups = util.SplitString(header) |
||||
case "Role": |
||||
// If Role header is specified, we update the user role of the default org
|
||||
if header != "" { |
||||
rt := org.RoleType(header) |
||||
if rt.IsValid() { |
||||
extUser.OrgRoles = map[int64]org.RoleType{} |
||||
orgID := int64(1) |
||||
if auth.cfg.AutoAssignOrg && auth.cfg.AutoAssignOrgId > 0 { |
||||
orgID = int64(auth.cfg.AutoAssignOrgId) |
||||
} |
||||
extUser.OrgRoles[orgID] = rt |
||||
} |
||||
} |
||||
default: |
||||
reflect.ValueOf(extUser).Elem().FieldByName(field).SetString(header) |
||||
} |
||||
}) |
||||
|
||||
upsert := &login.UpsertUserCommand{ |
||||
ReqContext: reqCtx, |
||||
SignupAllowed: auth.cfg.AuthProxyAutoSignUp, |
||||
ExternalUser: extUser, |
||||
UserLookupParams: login.UserLookupParams{ |
||||
UserID: nil, |
||||
Login: &extUser.Login, |
||||
Email: &extUser.Email, |
||||
}, |
||||
} |
||||
|
||||
result, err := auth.loginService.UpsertUser(reqCtx.Req.Context(), upsert) |
||||
if err != nil { |
||||
return 0, err |
||||
} |
||||
|
||||
return result.ID, nil |
||||
} |
||||
|
||||
// getDecodedHeader gets decoded value of a header with given headerName
|
||||
func (auth *AuthProxy) getDecodedHeader(reqCtx *contextmodel.ReqContext, headerName string) string { |
||||
headerValue := reqCtx.Req.Header.Get(headerName) |
||||
|
||||
if auth.cfg.AuthProxyHeadersEncoded { |
||||
headerValue = util.DecodeQuotedPrintable(headerValue) |
||||
} |
||||
|
||||
return headerValue |
||||
} |
||||
|
||||
// headersIterator iterates over all non-empty supported additional headers
|
||||
func (auth *AuthProxy) headersIterator(reqCtx *contextmodel.ReqContext, fn func(field string, header string)) { |
||||
for _, field := range supportedHeaderFields { |
||||
h := auth.cfg.AuthProxyHeaders[field] |
||||
if h == "" { |
||||
continue |
||||
} |
||||
|
||||
if value := auth.getDecodedHeader(reqCtx, h); value != "" { |
||||
fn(field, strings.TrimSpace(value)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// GetSignedInUser gets full signed in user info.
|
||||
func (auth *AuthProxy) GetSignedInUser(userID int64, orgID int64) (*user.SignedInUser, error) { |
||||
return auth.userService.GetSignedInUser(context.Background(), &user.GetSignedInUserQuery{ |
||||
OrgID: orgID, |
||||
UserID: userID, |
||||
}) |
||||
} |
||||
|
||||
// Remember user in cache
|
||||
func (auth *AuthProxy) Remember(reqCtx *contextmodel.ReqContext, id int64) error { |
||||
key, err := auth.getKey(reqCtx) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
// Check if user already in cache
|
||||
cachedValue, err := auth.remoteCache.Get(reqCtx.Req.Context(), key) |
||||
if err == nil && len(cachedValue) != 0 { |
||||
return nil |
||||
} |
||||
|
||||
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute |
||||
|
||||
userIdPayload := []byte(strconv.FormatInt(id, 10)) |
||||
if err := auth.remoteCache.Set(reqCtx.Req.Context(), key, userIdPayload, expiration); err != nil { |
||||
return err |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// coerceProxyAddress gets network of the presented CIDR notation
|
||||
func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) { |
||||
proxyAddr = strings.TrimSpace(proxyAddr) |
||||
if !strings.Contains(proxyAddr, "/") { |
||||
proxyAddr = path.Join(proxyAddr, "32") |
||||
} |
||||
|
||||
_, network, err := net.ParseCIDR(proxyAddr) |
||||
if err != nil { |
||||
return nil, fmt.Errorf("could not parse the network: %w", err) |
||||
} |
||||
return network, nil |
||||
} |
||||
@ -1,168 +0,0 @@ |
||||
package authproxy |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/http" |
||||
"strconv" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/remotecache" |
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" |
||||
"github.com/grafana/grafana/pkg/services/ldap/service" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/login/loginservice" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/web" |
||||
) |
||||
|
||||
const hdrName = "markelog" |
||||
const id int64 = 42 |
||||
|
||||
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, configureReq func(*http.Request, *setting.Cfg)) (*AuthProxy, *contextmodel.ReqContext) { |
||||
t.Helper() |
||||
|
||||
req, err := http.NewRequest("POST", "http://example.com", nil) |
||||
require.NoError(t, err) |
||||
|
||||
cfg := setting.NewCfg() |
||||
|
||||
if configureReq != nil { |
||||
configureReq(req, cfg) |
||||
} else { |
||||
cfg.AuthProxyHeaderName = "X-Killa" |
||||
req.Header.Set(cfg.AuthProxyHeaderName, hdrName) |
||||
} |
||||
|
||||
ctx := &contextmodel.ReqContext{ |
||||
Context: &web.Context{Req: req}, |
||||
} |
||||
|
||||
loginService := loginservice.LoginServiceMock{ |
||||
ExpectedUser: &user.User{ |
||||
ID: id, |
||||
}, |
||||
} |
||||
|
||||
return ProvideAuthProxy(cfg, remoteCache, loginService, nil, nil, service.NewLDAPFakeService()), ctx |
||||
} |
||||
|
||||
func TestMiddlewareContext(t *testing.T) { |
||||
cache := remotecache.NewFakeStore(t) |
||||
|
||||
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) { |
||||
const id int64 = 33 |
||||
// Set cache key
|
||||
h, err := HashCacheKey(hdrName) |
||||
require.NoError(t, err) |
||||
key := fmt.Sprintf(CachePrefix, h) |
||||
userIdPayload := []byte(strconv.FormatInt(id, 10)) |
||||
err = cache.Set(context.Background(), key, userIdPayload, 0) |
||||
require.NoError(t, err) |
||||
// Set up the middleware
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil) |
||||
gotKey, err := auth.getKey(reqCtx) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, key, gotKey) |
||||
|
||||
gotID, err := auth.Login(reqCtx, false) |
||||
require.NoError(t, err) |
||||
|
||||
assert.Equal(t, id, gotID) |
||||
}) |
||||
|
||||
t.Run("When the cache key contains additional headers", func(t *testing.T) { |
||||
const id int64 = 33 |
||||
const group = "grafana-core-team" |
||||
const role = "Admin" |
||||
|
||||
h, err := HashCacheKey(hdrName + "-" + group + "-" + role) |
||||
require.NoError(t, err) |
||||
key := fmt.Sprintf(CachePrefix, h) |
||||
userIdPayload := []byte(strconv.FormatInt(id, 10)) |
||||
err = cache.Set(context.Background(), key, userIdPayload, 0) |
||||
require.NoError(t, err) |
||||
|
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { |
||||
cfg.AuthProxyHeaderName = "X-Killa" |
||||
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS", "Role": "X-WEBAUTH-ROLE"} |
||||
req.Header.Set(cfg.AuthProxyHeaderName, hdrName) |
||||
req.Header.Set("X-WEBAUTH-GROUPS", group) |
||||
req.Header.Set("X-WEBAUTH-ROLE", role) |
||||
}) |
||||
assert.Equal(t, "auth-proxy-sync-ttl:f5acfffd56daac98d502ef8c8b8c5d56", key) |
||||
|
||||
gotID, err := auth.Login(reqCtx, false) |
||||
require.NoError(t, err) |
||||
assert.Equal(t, id, gotID) |
||||
}) |
||||
} |
||||
|
||||
func TestMiddlewareContext_ldap(t *testing.T) { |
||||
t.Run("Logs in via LDAP", func(t *testing.T) { |
||||
cache := remotecache.NewFakeStore(t) |
||||
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil) |
||||
auth.cfg.LDAPAuthEnabled = true |
||||
ldapFake := &service.LDAPFakeService{ |
||||
ExpectedUser: &login.ExternalUserInfo{UserId: id}, |
||||
} |
||||
|
||||
auth.ldapService = ldapFake |
||||
|
||||
gotID, err := auth.Login(reqCtx, false) |
||||
require.NoError(t, err) |
||||
|
||||
assert.Equal(t, id, gotID) |
||||
assert.True(t, ldapFake.UserCalled) |
||||
}) |
||||
|
||||
t.Run("Gets nice error if LDAP is enabled, but not configured", func(t *testing.T) { |
||||
const id int64 = 42 |
||||
cache := remotecache.NewFakeStore(t) |
||||
|
||||
auth, reqCtx := prepareMiddleware(t, cache, nil) |
||||
auth.cfg.LDAPAuthEnabled = true |
||||
ldapFake := &service.LDAPFakeService{ |
||||
ExpectedUser: nil, |
||||
ExpectedError: service.ErrUnableToCreateLDAPClient, |
||||
} |
||||
|
||||
auth.ldapService = ldapFake |
||||
|
||||
gotID, err := auth.Login(reqCtx, false) |
||||
require.EqualError(t, err, "failed to get the user") |
||||
|
||||
assert.NotEqual(t, id, gotID) |
||||
assert.True(t, ldapFake.UserCalled) |
||||
}) |
||||
} |
||||
|
||||
func TestDecodeHeader(t *testing.T) { |
||||
cache := remotecache.NewFakeStore(t) |
||||
t.Run("should not decode header if not enabled in settings", func(t *testing.T) { |
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { |
||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" |
||||
cfg.AuthProxyHeadersEncoded = false |
||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") |
||||
}) |
||||
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
assert.Equal(t, "M=C3=BCnchen", header) |
||||
}) |
||||
|
||||
t.Run("should decode header if enabled in settings", func(t *testing.T) { |
||||
auth, reqCtx := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) { |
||||
cfg.AuthProxyHeaderName = "X-WEBAUTH-USER" |
||||
cfg.AuthProxyHeadersEncoded = true |
||||
req.Header.Set(cfg.AuthProxyHeaderName, "M=C3=BCnchen") |
||||
}) |
||||
|
||||
header := auth.getDecodedHeader(reqCtx, auth.cfg.AuthProxyHeaderName) |
||||
assert.Equal(t, "München", header) |
||||
}) |
||||
} |
||||
@ -1,123 +1,181 @@ |
||||
package contexthandler |
||||
package contexthandler_test |
||||
|
||||
import ( |
||||
"context" |
||||
"net" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"errors" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend/gtime" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
|
||||
"github.com/grafana/grafana/pkg/infra/log" |
||||
"github.com/grafana/grafana/pkg/api/routing" |
||||
"github.com/grafana/grafana/pkg/infra/tracing" |
||||
"github.com/grafana/grafana/pkg/services/auth" |
||||
"github.com/grafana/grafana/pkg/services/auth/authtest" |
||||
"github.com/grafana/grafana/pkg/services/authn" |
||||
"github.com/grafana/grafana/pkg/services/authn/authntest" |
||||
"github.com/grafana/grafana/pkg/services/contexthandler" |
||||
contextmodel "github.com/grafana/grafana/pkg/services/contexthandler/model" |
||||
"github.com/grafana/grafana/pkg/util" |
||||
"github.com/grafana/grafana/pkg/web" |
||||
"github.com/grafana/grafana/pkg/services/featuremgmt" |
||||
"github.com/grafana/grafana/pkg/services/login" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/web/webtest" |
||||
) |
||||
|
||||
func TestDontRotateTokensOnCancelledRequests(t *testing.T) { |
||||
ctxHdlr := getContextHandler(t) |
||||
tryRotateCallCount := 0 |
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{ |
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, |
||||
userAgent string) (bool, *auth.UserToken, error) { |
||||
tryRotateCallCount++ |
||||
return false, nil, nil |
||||
}, |
||||
} |
||||
|
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
reqContext, _, err := initTokenRotationScenario(ctx, t, ctxHdlr) |
||||
require.NoError(t, err) |
||||
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"} |
||||
|
||||
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext) |
||||
cancel() |
||||
fn(reqContext.Resp) |
||||
|
||||
assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted") |
||||
} |
||||
func TestContextHandler(t *testing.T) { |
||||
t.Run("should set auth error if authentication was unsuccessful", func(t *testing.T) { |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedErr: errors.New("some error")}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) { |
||||
require.False(t, c.IsSignedIn) |
||||
require.EqualValues(t, &user.SignedInUser{Permissions: map[int64]map[string][]string{}}, c.SignedInUser) |
||||
require.Error(t, c.LookupTokenErr) |
||||
}) |
||||
|
||||
_, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
}) |
||||
|
||||
func TestTokenRotationAtEndOfRequest(t *testing.T) { |
||||
ctxHdlr := getContextHandler(t) |
||||
ctxHdlr.AuthTokenService = &authtest.FakeUserAuthTokenService{ |
||||
TryRotateTokenProvider: func(ctx context.Context, token *auth.UserToken, clientIP net.IP, |
||||
userAgent string) (bool, *auth.UserToken, error) { |
||||
newToken, err := util.RandomHex(16) |
||||
require.NoError(t, err) |
||||
token.AuthToken = newToken |
||||
return true, token, nil |
||||
}, |
||||
} |
||||
|
||||
reqContext, rr, err := initTokenRotationScenario(context.Background(), t, ctxHdlr) |
||||
require.NoError(t, err) |
||||
reqContext.UserToken = &auth.UserToken{AuthToken: "oldtoken"} |
||||
|
||||
ctxHdlr.rotateEndOfRequestFunc(reqContext)(reqContext.Resp) |
||||
foundLoginCookie := false |
||||
// nolint:bodyclose
|
||||
resp := rr.Result() |
||||
t.Cleanup(func() { |
||||
err := resp.Body.Close() |
||||
assert.NoError(t, err) |
||||
t.Run("should set identity on successful authentication", func(t *testing.T) { |
||||
identity := &authn.Identity{ID: authn.NamespacedID(authn.NamespaceUser, 1), OrgID: 1} |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedIdentity: identity}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) { |
||||
require.True(t, c.IsSignedIn) |
||||
require.EqualValues(t, identity.SignedInUser(), c.SignedInUser) |
||||
require.NoError(t, c.LookupTokenErr) |
||||
}) |
||||
|
||||
_, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
}) |
||||
for _, c := range resp.Cookies() { |
||||
if c.Name == "login_token" { |
||||
foundLoginCookie = true |
||||
require.NotEqual(t, reqContext.UserToken.AuthToken, c.Value, "Auth token is still the same") |
||||
} |
||||
} |
||||
|
||||
assert.True(t, foundLoginCookie, "Could not find cookie") |
||||
} |
||||
t.Run("should not set IsSignedIn on anonymous identity", func(t *testing.T) { |
||||
identity := &authn.Identity{IsAnonymous: true, OrgID: 1} |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedIdentity: identity}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) { |
||||
require.False(t, c.IsSignedIn) |
||||
require.EqualValues(t, identity.SignedInUser(), c.SignedInUser) |
||||
require.NoError(t, c.LookupTokenErr) |
||||
}) |
||||
|
||||
_, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
}) |
||||
|
||||
func initTokenRotationScenario(ctx context.Context, t *testing.T, ctxHdlr *ContextHandler) ( |
||||
*contextmodel.ReqContext, *httptest.ResponseRecorder, error) { |
||||
t.Helper() |
||||
|
||||
ctxHdlr.Cfg.LoginCookieName = "login_token" |
||||
var err error |
||||
ctxHdlr.Cfg.LoginMaxLifetime, err = gtime.ParseDuration("7d") |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
|
||||
rr := httptest.NewRecorder() |
||||
req, err := http.NewRequestWithContext(ctx, "", "", nil) |
||||
if err != nil { |
||||
return nil, nil, err |
||||
} |
||||
reqContext := &contextmodel.ReqContext{ |
||||
Context: &web.Context{Req: req}, |
||||
Logger: log.New("testlogger"), |
||||
} |
||||
|
||||
mw := mockWriter{rr} |
||||
reqContext.Resp = mw |
||||
|
||||
return reqContext, rr, nil |
||||
} |
||||
t.Run("should set IsRenderCall when authenticated by render client", func(t *testing.T) { |
||||
identity := &authn.Identity{OrgID: 1, AuthenticatedBy: login.RenderModule} |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedIdentity: identity}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) { |
||||
require.True(t, c.IsSignedIn) |
||||
require.True(t, c.IsRenderCall) |
||||
require.EqualValues(t, identity.SignedInUser(), c.SignedInUser) |
||||
require.NoError(t, c.LookupTokenErr) |
||||
}) |
||||
|
||||
_, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
}) |
||||
|
||||
type mockWriter struct { |
||||
*httptest.ResponseRecorder |
||||
} |
||||
t.Run("should delete session cookie on invalid session", func(t *testing.T) { |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedErr: auth.ErrInvalidSessionToken}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) {}) |
||||
|
||||
res, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
cookies := res.Cookies() |
||||
require.Len(t, cookies, 1) |
||||
require.Equal(t, cookies[0].String(), "grafana_session_expiry=; Path=/; Max-Age=0") |
||||
require.NoError(t, res.Body.Close()) |
||||
}) |
||||
|
||||
func (mw mockWriter) Flush() {} |
||||
func (mw mockWriter) Status() int { return 0 } |
||||
func (mw mockWriter) Size() int { return 0 } |
||||
func (mw mockWriter) Written() bool { return false } |
||||
func (mw mockWriter) Before(web.BeforeFunc) {} |
||||
func (mw mockWriter) Push(target string, opts *http.PushOptions) error { |
||||
return nil |
||||
} |
||||
func (mw mockWriter) CloseNotify() <-chan bool { |
||||
return make(<-chan bool) |
||||
} |
||||
func (mw mockWriter) Unwrap() http.ResponseWriter { |
||||
return mw |
||||
t.Run("should delete session cookie when oauth token refresh failed", func(t *testing.T) { |
||||
handler := contexthandler.ProvideService( |
||||
setting.NewCfg(), |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedErr: authn.ErrExpiredAccessToken.Errorf("")}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) {}) |
||||
|
||||
res, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
cookies := res.Cookies() |
||||
require.Len(t, cookies, 1) |
||||
require.Equal(t, cookies[0].String(), "grafana_session_expiry=; Path=/; Max-Age=0") |
||||
require.NoError(t, res.Body.Close()) |
||||
}) |
||||
|
||||
t.Run("should store auth header in context", func(t *testing.T) { |
||||
cfg := setting.NewCfg() |
||||
cfg.JWTAuthEnabled = true |
||||
cfg.JWTAuthHeaderName = "jwt-header" |
||||
cfg.AuthProxyEnabled = true |
||||
cfg.AuthProxyHeaderName = "proxy-header" |
||||
cfg.AuthProxyHeaders = map[string]string{ |
||||
"name": "proxy-header-name", |
||||
} |
||||
|
||||
handler := contexthandler.ProvideService( |
||||
cfg, |
||||
tracing.NewFakeTracer(), |
||||
featuremgmt.WithFeatures(), |
||||
&authntest.FakeService{ExpectedIdentity: &authn.Identity{}}, |
||||
) |
||||
|
||||
server := webtest.NewServer(t, routing.NewRouteRegister()) |
||||
server.Mux.Use(handler.Middleware) |
||||
server.Mux.Get("/api/handler", func(c *contextmodel.ReqContext) { |
||||
list := contexthandler.AuthHTTPHeaderListFromContext(c.Req.Context()) |
||||
require.NotNil(t, list) |
||||
|
||||
assert.Contains(t, list.Items, "jwt-header") |
||||
assert.Contains(t, list.Items, "proxy-header") |
||||
assert.Contains(t, list.Items, "proxy-header-name") |
||||
assert.Contains(t, list.Items, "Authorization") |
||||
}) |
||||
|
||||
_, err := server.Send(server.NewGetRequest("/api/handler")) |
||||
require.NoError(t, err) |
||||
}) |
||||
} |
||||
|
||||
Loading…
Reference in new issue