From c4b4baea2a4c5e4d9975dcd47002c6bc4f7d705c Mon Sep 17 00:00:00 2001 From: Karl Persson Date: Tue, 20 Dec 2022 21:18:48 +0100 Subject: [PATCH] AuthN: set org id for authentication request in service (#60528) * AuthN: Replicate functionallity to get org id for request * Authn: parse org id for the request and populate the auth request with it * AuthN: add simple mock for client to use in test * AuthN: add tests to verify that authentication is called with correct org id * AuthN: Add ClientParams to mock * AuthN: Fix flaky org id selection --- pkg/services/authn/authn.go | 2 + pkg/services/authn/authnimpl/service.go | 53 ++++++++++++-- pkg/services/authn/authnimpl/service_test.go | 73 +++++++++++++++++++ .../authn/authnimpl/usersync/orgsync.go | 19 +++-- pkg/services/authn/authntest/mock.go | 36 +++++++++ 5 files changed, 170 insertions(+), 13 deletions(-) create mode 100644 pkg/services/authn/authntest/mock.go diff --git a/pkg/services/authn/authn.go b/pkg/services/authn/authn.go index ecd66c2bb67..8f81c437c64 100644 --- a/pkg/services/authn/authn.go +++ b/pkg/services/authn/authn.go @@ -42,6 +42,8 @@ type Client interface { } type Request struct { + // OrgID will be populated by authn.Service + OrgID int64 HTTPRequest *http.Request } diff --git a/pkg/services/authn/authnimpl/service.go b/pkg/services/authn/authnimpl/service.go index 9e40ab427da..602db60823e 100644 --- a/pkg/services/authn/authnimpl/service.go +++ b/pkg/services/authn/authnimpl/service.go @@ -2,6 +2,8 @@ package authnimpl import ( "context" + "net/http" + "strconv" "github.com/grafana/grafana/pkg/infra/log" "github.com/grafana/grafana/pkg/infra/tracing" @@ -15,6 +17,7 @@ import ( "go.opentelemetry.io/otel/attribute" ) +// make sure service implements authn.Service interface var _ authn.Service = new(Service) func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Service, apikeyService apikey.Service, userService user.Service) *Service { @@ -24,7 +27,6 @@ func ProvideService(cfg *setting.Cfg, tracer tracing.Tracer, orgService org.Serv clients: make(map[string]authn.Client), tracer: tracer, postAuthHooks: []authn.PostAuthHookFn{}, - userService: userService, } s.clients[authn.ClientAPIKey] = clients.ProvideAPIKey(apikeyService, userService) @@ -46,12 +48,9 @@ type Service struct { log log.Logger cfg *setting.Cfg clients map[string]authn.Client - // postAuthHooks are called after a successful authentication. They can modify the identity. postAuthHooks []authn.PostAuthHookFn - - tracer tracing.Tracer - userService user.Service + tracer tracing.Tracer } func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Request) (*authn.Identity, bool, error) { @@ -74,6 +73,7 @@ func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Requ return nil, false, nil } + r.OrgID = orgIDFromRequest(r) identity, err := c.Authenticate(ctx, r) if err != nil { logger.Warn("auth client could not authenticate request", "client", client, "error", err) @@ -103,3 +103,46 @@ func (s *Service) Authenticate(ctx context.Context, client string, r *authn.Requ func (s *Service) RegisterPostAuthHook(hook authn.PostAuthHookFn) { s.postAuthHooks = append(s.postAuthHooks, hook) } + +func orgIDFromRequest(r *authn.Request) int64 { + if r.HTTPRequest == nil { + return 0 + } + + orgID := orgIDFromQuery(r.HTTPRequest) + if orgID > 0 { + return orgID + } + + return orgIDFromHeader(r.HTTPRequest) +} + +// name of query string used to target specific org for request +const orgIDTargetQuery = "targetOrgId" + +func orgIDFromQuery(req *http.Request) int64 { + params := req.URL.Query() + if !params.Has(orgIDTargetQuery) { + return 0 + } + id, err := strconv.ParseInt(params.Get(orgIDTargetQuery), 10, 64) + if err != nil { + return 0 + } + return id +} + +// name of header containing org id for request +const orgIDHeaderName = "X-Grafana-Org-Id" + +func orgIDFromHeader(req *http.Request) int64 { + header := req.Header.Get(orgIDHeaderName) + if header == "" { + return 0 + } + id, err := strconv.ParseInt(header, 10, 64) + if err != nil { + return 0 + } + return id +} diff --git a/pkg/services/authn/authnimpl/service_test.go b/pkg/services/authn/authnimpl/service_test.go index 1255813a358..b12d49d739d 100644 --- a/pkg/services/authn/authnimpl/service_test.go +++ b/pkg/services/authn/authnimpl/service_test.go @@ -3,6 +3,8 @@ package authnimpl import ( "context" "errors" + "net/http" + "net/url" "testing" "github.com/stretchr/testify/assert" @@ -59,6 +61,77 @@ func TestService_Authenticate(t *testing.T) { } } +func TestService_AuthenticateOrgID(t *testing.T) { + type TestCase struct { + desc string + req *authn.Request + expectedOrgID int64 + } + + tests := []TestCase{ + { + desc: "should set org id when present in header", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{orgIDHeaderName: {"1"}}, + URL: &url.URL{}, + }}, + expectedOrgID: 1, + }, + { + desc: "should set org id when present in url", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: mustParseURL("http://localhost/?targetOrgId=2"), + }}, + expectedOrgID: 2, + }, + { + desc: "should prioritise org id from url when present in both header and url", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{orgIDHeaderName: {"1"}}, + URL: mustParseURL("http://localhost/?targetOrgId=2"), + }}, + expectedOrgID: 2, + }, + { + desc: "should set org id to 0 when missing in both header and url", + req: &authn.Request{HTTPRequest: &http.Request{ + Header: map[string][]string{}, + URL: &url.URL{}, + }}, + expectedOrgID: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + var calledWith int64 + s := setupTests(t, func(svc *Service) { + svc.clients["fake"] = authntest.MockClient{ + AuthenticateFunc: func(ctx context.Context, r *authn.Request) (*authn.Identity, error) { + calledWith = r.OrgID + return nil, nil + }, + TestFunc: func(ctx context.Context, r *authn.Request) bool { + return true + }, + } + }) + + _, _, _ = s.Authenticate(context.Background(), "fake", tt.req) + assert.Equal(t, tt.expectedOrgID, calledWith) + }) + } +} + +func mustParseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + func setupTests(t *testing.T, opts ...func(svc *Service)) *Service { t.Helper() diff --git a/pkg/services/authn/authnimpl/usersync/orgsync.go b/pkg/services/authn/authnimpl/usersync/orgsync.go index 0afb5cecbb1..fa2e4fa677e 100644 --- a/pkg/services/authn/authnimpl/usersync/orgsync.go +++ b/pkg/services/authn/authnimpl/usersync/orgsync.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "sort" "github.com/grafana/grafana/pkg/cmd/grafana-cli/logger" "github.com/grafana/grafana/pkg/infra/log" @@ -65,8 +66,10 @@ func (s *OrgSync) SyncOrgUser(ctx context.Context, clientParams *authn.ClientPar } } + orgIDs := make([]int64, 0, len(id.OrgRoles)) // add any new org roles for orgId, orgRole := range id.OrgRoles { + orgIDs = append(orgIDs, orgId) if _, exists := handledOrgIds[orgId]; exists { continue } @@ -99,17 +102,17 @@ func (s *OrgSync) SyncOrgUser(ctx context.Context, clientParams *authn.ClientPar } } + // Note: sort all org ids to not make it flaky, for now we default to the lowest id + sort.Slice(orgIDs, func(i, j int) bool { return orgIDs[i] < orgIDs[j] }) // update user's default org if needed if _, ok := id.OrgRoles[id.OrgID]; !ok { - for orgId := range id.OrgRoles { - id.OrgID = orgId - break + if len(orgIDs) > 0 { + id.OrgID = orgIDs[0] + return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{ + UserID: userID, + OrgID: id.OrgID, + }) } - - return s.userService.SetUsingOrg(ctx, &user.SetUsingOrgCommand{ - UserID: userID, - OrgID: id.OrgID, - }) } return nil diff --git a/pkg/services/authn/authntest/mock.go b/pkg/services/authn/authntest/mock.go new file mode 100644 index 00000000000..4aca8ea11cf --- /dev/null +++ b/pkg/services/authn/authntest/mock.go @@ -0,0 +1,36 @@ +package authntest + +import ( + "context" + + "github.com/grafana/grafana/pkg/services/authn" +) + +var _ authn.Client = new(MockClient) + +type MockClient struct { + AuthenticateFunc func(ctx context.Context, r *authn.Request) (*authn.Identity, error) + ClientParamsFunc func() *authn.ClientParams + TestFunc func(ctx context.Context, r *authn.Request) bool +} + +func (m MockClient) Authenticate(ctx context.Context, r *authn.Request) (*authn.Identity, error) { + if m.AuthenticateFunc != nil { + return m.AuthenticateFunc(ctx, r) + } + return nil, nil +} + +func (m MockClient) ClientParams() *authn.ClientParams { + if m.ClientParamsFunc != nil { + return m.ClientParamsFunc() + } + return nil +} + +func (m MockClient) Test(ctx context.Context, r *authn.Request) bool { + if m.TestFunc != nil { + return m.TestFunc(ctx, r) + } + return false +}