mirror of https://github.com/grafana/grafana
refactor public dashboards middleware testing (#55706)
This PR refactors how we add the orgId to the context on a public dashboard paths. We also split out accessToken handling into its own package and rework status code for "RequiresValidAccessToken". We will be modeling all endpoints to use these status codes going forward. Additionally, it includes a scaffold for better middleware testing and refactors existing tests to table drive tests.pull/55878/head
parent
609abf00d1
commit
331110bde5
@ -1,105 +1,187 @@ |
||||
package api |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"testing" |
||||
|
||||
"github.com/grafana/grafana-plugin-sdk-go/backend" |
||||
"errors" |
||||
|
||||
"github.com/grafana/grafana/pkg/models" |
||||
"github.com/grafana/grafana/pkg/services/contexthandler/ctxkey" |
||||
fakeDatasources "github.com/grafana/grafana/pkg/services/datasources/fakes" |
||||
"github.com/grafana/grafana/pkg/services/publicdashboards" |
||||
publicdashboardsService "github.com/grafana/grafana/pkg/services/publicdashboards/service" |
||||
"github.com/grafana/grafana/pkg/services/query" |
||||
"github.com/grafana/grafana/pkg/setting" |
||||
"github.com/grafana/grafana/pkg/services/publicdashboards/internal/tokens" |
||||
"github.com/grafana/grafana/pkg/services/user" |
||||
"github.com/grafana/grafana/pkg/web" |
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/mock" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestRequiresValidAccessToken(t *testing.T) { |
||||
t.Run("Returns 404 when access token is empty", func(t *testing.T) { |
||||
request, err := http.NewRequest("GET", "/api/public/ma/events/", nil) |
||||
require.NoError(t, err) |
||||
|
||||
resp := runMiddleware(request, mockAccessTokenExistsResponse(false, nil)) |
||||
|
||||
require.Equal(t, http.StatusNotFound, resp.Code) |
||||
}) |
||||
|
||||
t.Run("Returns 200 when public dashboard with access token exists", func(t *testing.T) { |
||||
request, err := http.NewRequest("GET", "/api/public/ma/events/myAccessToken", nil) |
||||
require.NoError(t, err) |
||||
|
||||
resp := runMiddleware(request, mockAccessTokenExistsResponse(true, nil)) |
||||
|
||||
require.Equal(t, http.StatusOK, resp.Code) |
||||
}) |
||||
|
||||
t.Run("Returns 400 when public dashboard with access token does not exist", func(t *testing.T) { |
||||
request, err := http.NewRequest("GET", "/api/public/ma/events/myAccessToken", nil) |
||||
require.NoError(t, err) |
||||
|
||||
resp := runMiddleware(request, mockAccessTokenExistsResponse(false, nil)) |
||||
|
||||
require.Equal(t, http.StatusBadRequest, resp.Code) |
||||
}) |
||||
|
||||
t.Run("Returns 500 when public dashboard service gives an error", func(t *testing.T) { |
||||
request, err := http.NewRequest("GET", "/api/public/ma/events/myAccessToken", nil) |
||||
require.NoError(t, err) |
||||
|
||||
resp := runMiddleware(request, mockAccessTokenExistsResponse(false, fmt.Errorf("error not found"))) |
||||
var validAccessToken, _ = tokens.GenerateAccessToken() |
||||
|
||||
require.Equal(t, http.StatusInternalServerError, resp.Code) |
||||
}) |
||||
func TestRequiresValidAccessToken(t *testing.T) { |
||||
tests := []struct { |
||||
Name string |
||||
Path string |
||||
AccessTokenExists bool |
||||
AccessTokenExistsErr error |
||||
AccessToken string |
||||
ExpectedResponseCode int |
||||
}{ |
||||
{ |
||||
Name: "Returns 200 when public dashboard with access token exists", |
||||
Path: "/api/public/ma/events/myAccesstoken", |
||||
AccessTokenExists: true, |
||||
AccessTokenExistsErr: nil, |
||||
AccessToken: validAccessToken, |
||||
ExpectedResponseCode: http.StatusOK, |
||||
}, |
||||
{ |
||||
Name: "Returns 400 when access token is empty", |
||||
Path: "/api/public/ma/events/", |
||||
AccessTokenExists: false, |
||||
AccessTokenExistsErr: nil, |
||||
AccessToken: "", |
||||
ExpectedResponseCode: http.StatusBadRequest, |
||||
}, |
||||
{ |
||||
Name: "Returns 400 when invalid access token", |
||||
Path: "/api/public/ma/events/myAccesstoken", |
||||
AccessTokenExists: false, |
||||
AccessTokenExistsErr: nil, |
||||
AccessToken: "invalidAccessToken", |
||||
ExpectedResponseCode: http.StatusBadRequest, |
||||
}, |
||||
{ |
||||
Name: "Returns 404 when public dashboard with access token does not exist", |
||||
Path: "/api/public/ma/events/myAccesstoken", |
||||
AccessTokenExists: false, |
||||
AccessTokenExistsErr: nil, |
||||
AccessToken: validAccessToken, |
||||
ExpectedResponseCode: http.StatusNotFound, |
||||
}, |
||||
{ |
||||
Name: "Returns 500 when public dashboard service gives an error", |
||||
Path: "/api/public/ma/events/myAccesstoken", |
||||
AccessTokenExists: false, |
||||
AccessTokenExistsErr: fmt.Errorf("error not found"), |
||||
AccessToken: validAccessToken, |
||||
ExpectedResponseCode: http.StatusInternalServerError, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.Name, func(t *testing.T) { |
||||
publicdashboardService := &publicdashboards.FakePublicDashboardService{} |
||||
publicdashboardService.On("AccessTokenExists", mock.Anything, mock.Anything).Return(tt.AccessTokenExists, tt.AccessTokenExistsErr) |
||||
params := map[string]string{":accessToken": tt.AccessToken} |
||||
mw := RequiresValidAccessToken(publicdashboardService) |
||||
_, resp := runMw(t, nil, "GET", tt.Path, params, mw) |
||||
require.Equal(t, tt.ExpectedResponseCode, resp.Code) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func mockAccessTokenExistsResponse(returnArguments ...interface{}) *publicdashboardsService.PublicDashboardServiceImpl { |
||||
fakeStore := &publicdashboards.FakePublicDashboardStore{} |
||||
fakeStore.On("AccessTokenExists", mock.Anything, mock.Anything).Return(returnArguments[0], returnArguments[1]) |
||||
|
||||
qds := query.ProvideService( |
||||
nil, |
||||
nil, |
||||
nil, |
||||
&fakePluginRequestValidator{}, |
||||
&fakeDatasources.FakeDataSourceService{}, |
||||
&fakePluginClient{ |
||||
QueryDataHandlerFunc: func(ctx context.Context, req *backend.QueryDataRequest) (*backend.QueryDataResponse, error) { |
||||
resp := backend.Responses{ |
||||
"A": backend.DataResponse{ |
||||
Error: fmt.Errorf("query failed"), |
||||
}, |
||||
} |
||||
return &backend.QueryDataResponse{Responses: resp}, nil |
||||
}, |
||||
func TestSetPublicDashboardOrgIdOnContext(t *testing.T) { |
||||
tests := []struct { |
||||
Name string |
||||
AccessToken string |
||||
OrgIdResp int64 |
||||
ErrorResp error |
||||
ExpectedOrgId int64 |
||||
}{ |
||||
{ |
||||
Name: "Adds orgId for enabled public dashboard", |
||||
AccessToken: validAccessToken, |
||||
OrgIdResp: 7, |
||||
ErrorResp: nil, |
||||
ExpectedOrgId: 7, |
||||
}, |
||||
&fakeOAuthTokenService{}, |
||||
) |
||||
|
||||
return publicdashboardsService.ProvideService(setting.NewCfg(), fakeStore, qds) |
||||
{ |
||||
Name: "Does not set orgId or fail with invalid accessToken", |
||||
AccessToken: "invalidAccessToken", |
||||
OrgIdResp: 0, |
||||
ErrorResp: nil, |
||||
ExpectedOrgId: 0, |
||||
}, |
||||
{ |
||||
Name: "Does not set orgId or fail with disabled public dashboard", |
||||
AccessToken: validAccessToken, |
||||
OrgIdResp: 0, |
||||
ErrorResp: nil, |
||||
ExpectedOrgId: 0, |
||||
}, |
||||
{ |
||||
Name: "Does not set orgId or fail with error querying public dashboard", |
||||
AccessToken: validAccessToken, |
||||
OrgIdResp: 0, |
||||
ErrorResp: errors.New("database error of some sort"), |
||||
ExpectedOrgId: 0, |
||||
}, |
||||
{ |
||||
Name: "Does not set orgId or fail with missing public dashboard", |
||||
AccessToken: validAccessToken, |
||||
OrgIdResp: 0, |
||||
ErrorResp: nil, |
||||
ExpectedOrgId: 0, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.Name, func(t *testing.T) { |
||||
publicdashboardService := &publicdashboards.FakePublicDashboardService{} |
||||
publicdashboardService.On("GetPublicDashboardOrgId", mock.Anything, tt.AccessToken).Return( |
||||
tt.OrgIdResp, |
||||
tt.ErrorResp, |
||||
) |
||||
|
||||
params := map[string]string{":accessToken": tt.AccessToken} |
||||
mw := SetPublicDashboardOrgIdOnContext(publicdashboardService) |
||||
ctx, _ := runMw(t, nil, "GET", "/public-dashboard/myaccesstoken", params, mw) |
||||
assert.Equal(t, tt.ExpectedOrgId, ctx.OrgID) |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func runMiddleware(request *http.Request, pubdashService *publicdashboardsService.PublicDashboardServiceImpl) *httptest.ResponseRecorder { |
||||
recorder := httptest.NewRecorder() |
||||
m := web.New() |
||||
initCtx := &models.ReqContext{} |
||||
m.Use(func(c *web.Context) { |
||||
initCtx.Context = c |
||||
c.Req = c.Req.WithContext(ctxkey.Set(c.Req.Context(), initCtx)) |
||||
func TestSetPublicDashboardFlag(t *testing.T) { |
||||
t.Run("Adds context.IsPublicDashboardView=true to request", func(t *testing.T) { |
||||
ctx := &models.ReqContext{} |
||||
SetPublicDashboardFlag(ctx) |
||||
assert.True(t, ctx.IsPublicDashboardView) |
||||
}) |
||||
m.Get("/api/public/ma/events/:accessToken", RequiresValidAccessToken(pubdashService), mockValidRequestHandler) |
||||
m.ServeHTTP(recorder, request) |
||||
|
||||
return recorder |
||||
} |
||||
|
||||
func mockValidRequestHandler(c *models.ReqContext) { |
||||
resp := make(map[string]interface{}) |
||||
resp["message"] = "Valid request" |
||||
c.JSON(http.StatusOK, resp) |
||||
// This is a helper to test middleware. It handles creating a
|
||||
// proper models.ReqContext, setting web parameters, executing middleware, and
|
||||
// returning a response. Response will default to result of
|
||||
// httptest.NewRecorder() return value and will only change if modified by the
|
||||
// middlware as this will no accept a handler method
|
||||
func runMw(t *testing.T, ctx *models.ReqContext, httpmethod string, path string, webparams map[string]string, mw func(c *models.ReqContext)) (*models.ReqContext, *httptest.ResponseRecorder) { |
||||
// create valid request context and set 0 values if they don't exist
|
||||
if ctx == nil { |
||||
ctx = &models.ReqContext{} |
||||
} |
||||
if ctx.Context == nil { |
||||
ctx.Context = &web.Context{} |
||||
} |
||||
if ctx.SignedInUser == nil { |
||||
ctx.SignedInUser = &user.SignedInUser{} |
||||
} |
||||
|
||||
// create request and add params
|
||||
request, err := http.NewRequest(httpmethod, path, nil) |
||||
require.NoError(t, err) |
||||
request = web.SetURLParams(request, webparams) |
||||
ctx.Req = request |
||||
|
||||
// setup response recorder to return
|
||||
response := httptest.NewRecorder() |
||||
ctx.Context.Resp = web.NewResponseWriter("GET", response) |
||||
|
||||
// run middleware
|
||||
mw(ctx) |
||||
|
||||
// return result
|
||||
return ctx, response |
||||
} |
||||
|
||||
@ -0,0 +1,23 @@ |
||||
package tokens |
||||
|
||||
import ( |
||||
"fmt" |
||||
|
||||
"github.com/google/uuid" |
||||
) |
||||
|
||||
// generates a uuid formatted without dashes to use as access token
|
||||
func GenerateAccessToken() (string, error) { |
||||
token, err := uuid.NewRandom() |
||||
if err != nil { |
||||
return "", err |
||||
} |
||||
|
||||
return fmt.Sprintf("%x", token[:]), nil |
||||
} |
||||
|
||||
// asserts that an accessToken is a valid uuid
|
||||
func IsValidAccessToken(token string) bool { |
||||
_, err := uuid.Parse(token) |
||||
return err == nil |
||||
} |
||||
@ -0,0 +1,38 @@ |
||||
package tokens |
||||
|
||||
import ( |
||||
"strings" |
||||
"testing" |
||||
|
||||
"github.com/stretchr/testify/assert" |
||||
"github.com/stretchr/testify/require" |
||||
) |
||||
|
||||
func TestGenerateAccessToken(t *testing.T) { |
||||
accessToken, err := GenerateAccessToken() |
||||
|
||||
t.Run("length", func(t *testing.T) { |
||||
require.NoError(t, err) |
||||
assert.Equal(t, 32, len(accessToken)) |
||||
}) |
||||
|
||||
t.Run("no - ", func(t *testing.T) { |
||||
assert.False(t, strings.Contains("-", accessToken)) |
||||
}) |
||||
} |
||||
|
||||
func TestValidAccessToken(t *testing.T) { |
||||
t.Run("true", func(t *testing.T) { |
||||
uuid, _ := GenerateAccessToken() |
||||
assert.True(t, IsValidAccessToken(uuid)) |
||||
}) |
||||
|
||||
t.Run("false when blank", func(t *testing.T) { |
||||
assert.False(t, IsValidAccessToken("")) |
||||
}) |
||||
|
||||
t.Run("false when can't be parsed by uuid lib", func(t *testing.T) { |
||||
// too long
|
||||
assert.False(t, IsValidAccessToken("0123456789012345678901234567890123456789")) |
||||
}) |
||||
} |
||||
Loading…
Reference in new issue