diff --git a/pkg/api/login.go b/pkg/api/login.go index f66a2e2740d..06688c97297 100644 --- a/pkg/api/login.go +++ b/pkg/api/login.go @@ -10,7 +10,6 @@ import ( "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/infra/metrics" "github.com/grafana/grafana/pkg/infra/network" "github.com/grafana/grafana/pkg/login" @@ -207,7 +206,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext) response.Response { Cfg: hs.Cfg, } - err := bus.Dispatch(c.Req.Context(), authQuery) + err := login.AuthenticateUserFunc(c.Req.Context(), authQuery) authModule = authQuery.AuthModule if err != nil { resp = response.Error(401, "Invalid username or password", err) diff --git a/pkg/api/login_test.go b/pkg/api/login_test.go index e25d2cd1c82..1c126f8ad79 100644 --- a/pkg/api/login_test.go +++ b/pkg/api/login_test.go @@ -16,7 +16,6 @@ import ( "github.com/grafana/grafana/pkg/api/dtos" "github.com/grafana/grafana/pkg/api/response" "github.com/grafana/grafana/pkg/api/routing" - "github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/components/simplejson" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/login" @@ -155,7 +154,6 @@ func TestLoginErrorCookieAPIEndpoint(t *testing.T) { func TestLoginViewRedirect(t *testing.T) { fakeSetIndexViewData(t) - fakeViewIndex(t) sc := setupScenarioContext(t, "/login") cfg := setting.NewCfg() @@ -348,13 +346,13 @@ func TestLoginPostRedirect(t *testing.T) { return hs.LoginPost(c) }) - bus.AddHandler("grafana-auth", func(ctx context.Context, query *models.LoginUserQuery) error { - query.User = &models.User{ - Id: 42, - Email: "", - } - return nil - }) + user := &models.User{ + Id: 42, + Email: "", + } + + mockAuthenticateUserFunc(user, "", nil) + t.Cleanup(resetAuthenticateUserFunc) redirectCases := []redirectCase{ { @@ -441,6 +439,7 @@ func TestLoginPostRedirect(t *testing.T) { for _, c := range redirectCases { hs.Cfg.AppURL = c.appURL hs.Cfg.AppSubURL = c.appSubURL + t.Run(c.desc, func(t *testing.T) { expCookiePath := "/" if len(hs.Cfg.AppSubURL) > 0 { @@ -685,12 +684,8 @@ func TestLoginPostRunLokingHook(t *testing.T) { for _, c := range testCases { t.Run(c.desc, func(t *testing.T) { - bus.AddHandler("grafana-auth", func(ctx context.Context, query *models.LoginUserQuery) error { - query.User = c.authUser - query.AuthModule = c.authModule - return c.authErr - }) - + mockAuthenticateUserFunc(c.authUser, c.authModule, c.authErr) + t.Cleanup(resetAuthenticateUserFunc) sc.m.Post(sc.url, sc.defaultHandler) sc.fakeReqNoAssertions("POST", sc.url).exec() @@ -736,3 +731,14 @@ func (m *mockSocialService) GetOAuthHttpClient(name string) (*http.Client, error func (m *mockSocialService) GetConnector(string) (social.SocialConnector, error) { return m.socialConnector, m.err } + +func mockAuthenticateUserFunc(user *models.User, authmodule string, err error) { + login.AuthenticateUserFunc = func(ctx context.Context, query *models.LoginUserQuery) error { + query.User = user + query.AuthModule = authmodule + return err + } +} +func resetAuthenticateUserFunc() { + login.AuthenticateUserFunc = login.AuthenticateUser +} diff --git a/pkg/login/auth.go b/pkg/login/auth.go index 6e6f6dda6d6..97eb4acd62d 100644 --- a/pkg/login/auth.go +++ b/pkg/login/auth.go @@ -25,12 +25,14 @@ var ( var loginLogger = log.New("login") +var AuthenticateUserFunc = AuthenticateUser + func Init() { - bus.AddHandler("auth", authenticateUser) + bus.AddHandler("auth", AuthenticateUser) } -// authenticateUser authenticates the user via username & password -func authenticateUser(ctx context.Context, query *models.LoginUserQuery) error { +// AuthenticateUser authenticates the user via username & password +func AuthenticateUser(ctx context.Context, query *models.LoginUserQuery) error { if err := validateLoginAttempts(ctx, query); err != nil { return err } diff --git a/pkg/login/auth_test.go b/pkg/login/auth_test.go index ce3567997fb..c8f1d7ca488 100644 --- a/pkg/login/auth_test.go +++ b/pkg/login/auth_test.go @@ -21,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) { Username: "user", Password: "", } - err := authenticateUser(context.Background(), &loginQuery) + err := AuthenticateUserFunc(context.Background(), &loginQuery) require.EqualError(t, err, ErrPasswordEmpty.Error()) assert.False(t, sc.grafanaLoginWasCalled) @@ -35,7 +35,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrTooManyLoginAttempts.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -51,7 +51,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -68,7 +68,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, customErr.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -84,7 +84,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(false, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, models.ErrUserNotFound.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -100,7 +100,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrInvalidCredentials.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -116,7 +116,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, nil, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.NoError(t, err) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -133,7 +133,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, customErr, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, customErr.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) @@ -149,7 +149,7 @@ func TestAuthenticateUser(t *testing.T) { mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc) mockSaveInvalidLoginAttempt(sc) - err := authenticateUser(context.Background(), sc.loginUserQuery) + err := AuthenticateUserFunc(context.Background(), sc.loginUserQuery) require.EqualError(t, err, ErrInvalidCredentials.Error()) assert.True(t, sc.loginAttemptValidationWasCalled) diff --git a/pkg/services/contexthandler/authproxy/authproxy_test.go b/pkg/services/contexthandler/authproxy/authproxy_test.go index 96a8b56eb9c..885d92e047a 100644 --- a/pkg/services/contexthandler/authproxy/authproxy_test.go +++ b/pkg/services/contexthandler/authproxy/authproxy_test.go @@ -19,35 +19,6 @@ import ( "github.com/stretchr/testify/require" ) -type fakeMultiLDAP struct { - multildap.MultiLDAP - ID int64 - userCalled bool - loginCalled bool -} - -func (m *fakeMultiLDAP) Login(query *models.LoginUserQuery) ( - *models.ExternalUserInfo, error, -) { - m.loginCalled = true - result := &models.ExternalUserInfo{ - UserId: m.ID, - } - return result, nil -} - -func (m *fakeMultiLDAP) User(login string) ( - *models.ExternalUserInfo, - ldap.ServerConfig, - error, -) { - m.userCalled = true - result := &models.ExternalUserInfo{ - UserId: m.ID, - } - return result, ldap.ServerConfig{}, nil -} - const hdrName = "markelog" func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, cb func(*http.Request, *setting.Cfg)) *AuthProxy { @@ -152,7 +123,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { return true } - stub := &fakeMultiLDAP{ + stub := &multildap.MultiLDAPmock{ ID: id, } @@ -179,7 +150,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { require.NoError(t, err) assert.Equal(t, id, gotID) - assert.True(t, stub.userCalled) + assert.True(t, stub.UserCalled) }) t.Run("Gets nice error if LDAP is enabled, but not configured", func(t *testing.T) { @@ -205,7 +176,7 @@ func TestMiddlewareContext_ldap(t *testing.T) { auth := prepareMiddleware(t, cache, nil) - stub := &fakeMultiLDAP{ + stub := &multildap.MultiLDAPmock{ ID: id, } @@ -217,6 +188,6 @@ func TestMiddlewareContext_ldap(t *testing.T) { require.EqualError(t, err, "failed to get the user") assert.NotEqual(t, id, gotID) - assert.False(t, stub.loginCalled) + assert.False(t, stub.LoginCalled) }) } diff --git a/pkg/services/multildap/multidap_mock.go b/pkg/services/multildap/multidap_mock.go new file mode 100644 index 00000000000..0c7748a9369 --- /dev/null +++ b/pkg/services/multildap/multidap_mock.go @@ -0,0 +1,40 @@ +package multildap + +import ( + "github.com/grafana/grafana/pkg/models" + "github.com/grafana/grafana/pkg/services/ldap" +) + +type MultiLDAPmock struct { + MultiLDAP + ID int64 + UserCalled bool + LoginCalled bool + UserInfo *models.User + AuthModule string + ExpectedErr error +} + +func (m *MultiLDAPmock) Login(query *models.LoginUserQuery) ( + *models.ExternalUserInfo, error, +) { + m.LoginCalled = true + query.User = m.UserInfo + query.AuthModule = m.AuthModule + result := &models.ExternalUserInfo{ + UserId: m.ID, + } + return result, m.ExpectedErr +} + +func (m *MultiLDAPmock) User(login string) ( + *models.ExternalUserInfo, + ldap.ServerConfig, + error, +) { + m.UserCalled = true + result := &models.ExternalUserInfo{ + UserId: m.ID, + } + return result, ldap.ServerConfig{}, nil +}