The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/login/authinfoservice/database/database.go

312 lines
8.5 KiB

package database
import (
"context"
"encoding/base64"
"time"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/services/login"
"github.com/grafana/grafana/pkg/services/secrets"
"github.com/grafana/grafana/pkg/services/user"
)
// GetTime supplies the timestamp this store writes to UserAuth.Created.
// It defaults to time.Now and is a package-level variable, presumably so
// tests can substitute a fixed clock — TODO confirm.
var GetTime = time.Now
// AuthInfoStore reads and writes user_auth records through a SQL store.
// OAuth token fields are encrypted with the secrets service before being
// persisted and decrypted again when read back.
type AuthInfoStore struct {
// sqlStore runs the database sessions and transactions used by the store.
sqlStore db.DB
// secretsService encrypts and decrypts the stored OAuth token fields.
secretsService secrets.Service
// logger is the store's logger (named "login.authinfo.store" by ProvideAuthInfoStore).
logger log.Logger
// userService backs the GetUserById / GetUserByLogin / GetUserByEmail lookups.
userService user.Service
}
// ProvideAuthInfoStore builds the database-backed login.Store from the given
// SQL store, secrets service and user service, attaching a logger named
// "login.authinfo.store".
func ProvideAuthInfoStore(sqlStore db.DB, secretsService secrets.Service, userService user.Service) login.Store {
	// FIXME: disabled the metric collection for duplicate user entries
	// due to query performance issues that is clogging the users Grafana instance
	// InitDuplicateUserMetrics()
	return &AuthInfoStore{
		sqlStore:       sqlStore,
		secretsService: secretsService,
		logger:         log.New("login.authinfo.store"),
		userService:    userService,
	}
}
// GetAuthInfo looks up the most recently created auth record matching the
// query's UserId, AuthModule and AuthId, and returns it with its OAuth token
// fields decoded and decrypted.
//
// It returns user.ErrUserNotFound when the query carries neither a UserId nor
// an AuthId, or when no matching record exists. Database and decryption
// errors are returned as-is.
func (s *AuthInfoStore) GetAuthInfo(ctx context.Context, query *login.GetAuthInfoQuery) (*login.UserAuth, error) {
	if query.UserId == 0 && query.AuthId == "" {
		return nil, user.ErrUserNotFound
	}

	// The populated fields act as the match condition for Get.
	record := &login.UserAuth{
		UserId:     query.UserId,
		AuthModule: query.AuthModule,
		AuthId:     query.AuthId,
	}

	found := false
	lookupErr := s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		// Newest entry first, so Get picks the latest record.
		ok, getErr := sess.Desc("created").Get(record)
		found = ok
		return getErr
	})
	if lookupErr != nil {
		return nil, lookupErr
	}
	if !found {
		return nil, user.ErrUserNotFound
	}

	// Decrypt each stored token in place; the record is only returned once
	// every field has been decrypted successfully.
	encryptedFields := []*string{
		&record.OAuthAccessToken,
		&record.OAuthRefreshToken,
		&record.OAuthTokenType,
		&record.OAuthIdToken,
	}
	for _, field := range encryptedFields {
		plain, err := s.decodeAndDecrypt(*field)
		if err != nil {
			return nil, err
		}
		*field = plain
	}

	return record, nil
}
// GetUserLabels returns, for each of the requested user IDs that has at least
// one user_auth row, the auth module of one of its rows. Rows are read in
// ascending "created" order and later rows overwrite earlier ones, so the
// most recently created entry wins.
func (s *AuthInfoStore) GetUserLabels(ctx context.Context, query login.GetUserLabelsQuery) (map[int64]string, error) {
	// In() takes its values as interface{}, so box the IDs first.
	ids := make([]interface{}, len(query.UserIDs))
	for i, id := range query.UserIDs {
		ids[i] = id
	}

	rows := make([]login.UserAuth, 0)
	findRows := func(sess *db.Session) error {
		return sess.Table("user_auth").In("user_id", ids).OrderBy("created").Find(&rows)
	}
	if err := s.sqlStore.WithDbSession(ctx, findRows); err != nil {
		return nil, err
	}

	labels := make(map[int64]string, len(rows))
	for _, row := range rows {
		labels[row.UserId] = row.AuthModule
	}
	return labels, nil
}
// SetAuthInfo inserts a new user_auth record for the command's user, auth
// module and auth ID, stamped with GetTime(). When the command carries an
// OAuth token, its access token, refresh token, token type and (if present
// and non-empty) "id_token" extra are encrypted and stored with the expiry.
// Encryption and database errors are returned as-is.
func (s *AuthInfoStore) SetAuthInfo(ctx context.Context, cmd *login.SetAuthInfoCommand) error {
	record := &login.UserAuth{
		UserId:     cmd.UserId,
		AuthModule: cmd.AuthModule,
		AuthId:     cmd.AuthId,
		Created:    GetTime(),
	}

	if token := cmd.OAuthToken; token != nil {
		var err error
		if record.OAuthAccessToken, err = s.encryptAndEncode(token.AccessToken); err != nil {
			return err
		}
		if record.OAuthRefreshToken, err = s.encryptAndEncode(token.RefreshToken); err != nil {
			return err
		}
		if record.OAuthTokenType, err = s.encryptAndEncode(token.TokenType); err != nil {
			return err
		}

		// The ID token is optional; leave the column empty when it is absent.
		rawIDToken, isString := token.Extra("id_token").(string)
		if isString && rawIDToken != "" {
			if record.OAuthIdToken, err = s.encryptAndEncode(rawIDToken); err != nil {
				return err
			}
		}

		record.OAuthExpiry = token.Expiry
	}

	insert := func(sess *db.Session) error {
		_, err := sess.Insert(record)
		return err
	}
	return s.sqlStore.WithTransactionalDbSession(ctx, insert)
}
// UpdateAuthInfoDate sets authInfo.Created to GetTime() and persists only the
// "created" column for the row matching authInfo's Id, UserId and AuthModule.
// Refreshing the date avoids overlapping entries hiding the last used one
// (ex: LDAP->SAML->LDAP).
func (s *AuthInfoStore) UpdateAuthInfoDate(ctx context.Context, authInfo *login.UserAuth) error {
	authInfo.Created = GetTime()

	// Identify the row by its key fields only; the new timestamp is not part
	// of the match condition.
	matcher := &login.UserAuth{
		Id:         authInfo.Id,
		UserId:     authInfo.UserId,
		AuthModule: authInfo.AuthModule,
	}

	touch := func(sess *db.Session) error {
		_, err := sess.Cols("created").Update(authInfo, matcher)
		return err
	}
	return s.sqlStore.WithTransactionalDbSession(ctx, touch)
}
// UpdateAuthInfo updates the user_auth rows matching the command's user ID and
// auth module with the command's auth ID, a fresh GetTime() timestamp and,
// when an OAuth token is supplied, its encrypted access token, refresh token,
// token type, optional "id_token" extra and expiry.
//
// If the update touched more than one row, the duplicates sharing the same
// user ID, auth module and auth ID are removed so that a single row remains.
// Everything runs in one transactional session; encryption and database
// errors are returned as-is.
func (s *AuthInfoStore) UpdateAuthInfo(ctx context.Context, cmd *login.UpdateAuthInfoCommand) error {
	authUser := &login.UserAuth{
		UserId:     cmd.UserId,
		AuthModule: cmd.AuthModule,
		AuthId:     cmd.AuthId,
		Created:    GetTime(),
	}

	if cmd.OAuthToken != nil {
		secretAccessToken, err := s.encryptAndEncode(cmd.OAuthToken.AccessToken)
		if err != nil {
			return err
		}

		secretRefreshToken, err := s.encryptAndEncode(cmd.OAuthToken.RefreshToken)
		if err != nil {
			return err
		}

		secretTokenType, err := s.encryptAndEncode(cmd.OAuthToken.TokenType)
		if err != nil {
			return err
		}

		// The ID token is optional; it stays empty when absent.
		var secretIdToken string
		if idToken, ok := cmd.OAuthToken.Extra("id_token").(string); ok && idToken != "" {
			secretIdToken, err = s.encryptAndEncode(idToken)
			if err != nil {
				return err
			}
		}

		authUser.OAuthAccessToken = secretAccessToken
		authUser.OAuthRefreshToken = secretRefreshToken
		authUser.OAuthTokenType = secretTokenType
		authUser.OAuthIdToken = secretIdToken
		authUser.OAuthExpiry = cmd.OAuthToken.Expiry
	}

	return s.sqlStore.WithTransactionalDbSession(ctx, func(sess *db.Session) error {
		// MustCols forces o_auth_expiry into the UPDATE even when it holds
		// the zero value.
		upd, err := sess.MustCols("o_auth_expiry").Where("user_id = ? AND auth_module = ?", cmd.UserId, cmd.AuthModule).Update(authUser)
		if err != nil {
			// Fail fast: do not log a successful update or attempt cleanup
			// after a failed statement.
			return err
		}

		s.logger.Debug("Updated user_auth", "user_id", cmd.UserId, "auth_id", cmd.AuthId, "auth_module", cmd.AuthModule, "rows", upd)

		if upd <= 1 {
			return nil
		}

		// Clean up duplicated entries: keep one row and delete the rest.
		// NOTE(review): the SELECT has no ORDER BY, so which row is kept is
		// decided by the database — confirm that is acceptable.
		var id int64
		ok, err := sess.SQL(
			"SELECT id FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ?",
			cmd.UserId, cmd.AuthModule, cmd.AuthId,
		).Get(&id)
		if err != nil {
			return err
		}
		if !ok {
			return nil
		}

		_, err = sess.Exec(
			"DELETE FROM user_auth WHERE user_id = ? AND auth_module = ? AND auth_id = ? AND id != ?",
			cmd.UserId, cmd.AuthModule, cmd.AuthId, id,
		)
		return err
	})
}
// DeleteAuthInfo deletes the user_auth record(s) matching cmd.UserAuth inside
// a transactional session and returns any database error.
func (s *AuthInfoStore) DeleteAuthInfo(ctx context.Context, cmd *login.DeleteAuthInfoCommand) error {
	remove := func(sess *db.Session) error {
		_, err := sess.Delete(cmd.UserAuth)
		return err
	}
	return s.sqlStore.WithTransactionalDbSession(ctx, remove)
}
// DeleteUserAuthInfo removes every user_auth row belonging to userID and
// returns any database error.
func (s *AuthInfoStore) DeleteUserAuthInfo(ctx context.Context, userID int64) error {
	purge := func(sess *db.Session) error {
		_, err := sess.Exec("DELETE FROM user_auth WHERE user_id = ?", userID)
		return err
	}
	return s.sqlStore.WithDbSession(ctx, purge)
}
// GetUserById fetches the user with the given ID through the user service.
// It returns nil and the service's error when the lookup fails.
func (s *AuthInfoStore) GetUserById(ctx context.Context, id int64) (*user.User, error) {
	query := user.GetUserByIDQuery{ID: id}
	// Named usr (as in the sibling lookups) so the local does not shadow the
	// imported user package.
	usr, err := s.userService.GetByID(ctx, &query)
	if err != nil {
		return nil, err
	}

	return usr, nil
}
// GetUserByLogin fetches a user through the user service, passing the given
// value as the query's LoginOrEmail. It returns nil and the service's error
// when the lookup fails.
func (s *AuthInfoStore) GetUserByLogin(ctx context.Context, login string) (*user.User, error) {
	found, err := s.userService.GetByLogin(ctx, &user.GetUserByLoginQuery{LoginOrEmail: login})
	if err != nil {
		return nil, err
	}
	return found, nil
}
// GetUserByEmail fetches the user with the given email address through the
// user service. It returns nil and the service's error when the lookup fails.
func (s *AuthInfoStore) GetUserByEmail(ctx context.Context, email string) (*user.User, error) {
	found, err := s.userService.GetByEmail(ctx, &user.GetUserByEmailQuery{Email: email})
	if err != nil {
		return nil, err
	}
	return found, nil
}
// decodeAndDecrypt reverses encryptAndEncode: it base64-decodes str with the
// standard encoding and decrypts the result with the secrets service.
// An empty input yields an empty string and no error.
func (s *AuthInfoStore) decodeAndDecrypt(str string) (string, error) {
	// Bail out if empty string since it'll cause a segfault in Decrypt
	if len(str) == 0 {
		return "", nil
	}

	ciphertext, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return "", err
	}

	plaintext, err := s.secretsService.Decrypt(context.Background(), ciphertext)
	if err != nil {
		return "", err
	}

	return string(plaintext), nil
}
// encryptAndEncode encrypts str with the secrets service (without a scope)
// and returns the ciphertext encoded with the standard base64 encoder.
func (s *AuthInfoStore) encryptAndEncode(str string) (string, error) {
	ciphertext, err := s.secretsService.Encrypt(context.Background(), []byte(str), secrets.WithoutScope())
	if err != nil {
		return "", err
	}

	encoded := base64.StdEncoding.EncodeToString(ciphertext)
	return encoded, nil
}