The open and composable observability and data visualization platform. Visualize metrics, logs, and traces from multiple sources like Prometheus, Loki, Elasticsearch, InfluxDB, Postgres and many more.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
grafana/pkg/services/auth/authimpl/external_session_store.go

283 lines
6.8 KiB

package authimpl
import (
"context"
"crypto/sha256"
"encoding/base64"
"github.com/grafana/grafana/pkg/infra/db"
"github.com/grafana/grafana/pkg/infra/tracing"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/secrets"
)
// Compile-time check that *store implements auth.ExternalSessionStore.
var _ auth.ExternalSessionStore = (*store)(nil)
// store is the database-backed implementation of auth.ExternalSessionStore.
// Token and identifier fields are encrypted through secretsService before
// they are written and decrypted again after they are read.
type store struct {
// sqlStore is the database handle every method runs its session through.
sqlStore db.DB
// secretsService encrypts and decrypts the sensitive session fields.
secretsService secrets.Service
// tracer starts one span per public store method.
tracer tracing.Tracer
}
// provideExternalSessionStore builds the database-backed
// auth.ExternalSessionStore from its three collaborators: the database
// handle, the secrets service used for field encryption, and the tracer.
func provideExternalSessionStore(sqlStore db.DB, secretService secrets.Service, tracer tracing.Tracer) auth.ExternalSessionStore {
	impl := &store{
		tracer:         tracer,
		secretsService: secretService,
		sqlStore:       sqlStore,
	}
	return impl
}
// Get loads the external session with the given ID and returns it with its
// secret fields decrypted.
//
// It returns auth.ErrExternalSessionNotFound when no row matches, and
// otherwise propagates any database or decryption error unchanged.
func (s *store) Get(ctx context.Context, ID int64) (*auth.ExternalSession, error) {
	ctx, span := s.tracer.Start(ctx, "externalsession.Get")
	defer span.End()

	record := &auth.ExternalSession{ID: ID}

	// lookup fills record in place, using its ID as the query condition.
	lookup := func(sess *db.Session) error {
		switch found, err := sess.Get(record); {
		case err != nil:
			return err
		case !found:
			return auth.ErrExternalSessionNotFound
		default:
			return nil
		}
	}

	if err := s.sqlStore.WithDbSession(ctx, lookup); err != nil {
		return nil, err
	}
	if err := s.decryptSecrets(record); err != nil {
		return nil, err
	}
	return record, nil
}
// List returns the external sessions matching the non-zero fields of query,
// with their secret fields decrypted.
//
// query.ID is matched directly. query.SessionID and query.NameID are matched
// through their SHA-256 digests (unpadded base64) against the SessionIDHash
// and NameIDHash fields. The returned slice is never nil on success.
func (s *store) List(ctx context.Context, query *auth.ListExternalSessionQuery) ([]*auth.ExternalSession, error) {
	ctx, span := s.tracer.Start(ctx, "externalsession.List")
	defer span.End()

	// digest returns the unpadded base64 form of the SHA-256 sum of value.
	digest := func(value string) string {
		sum := sha256.Sum256([]byte(value))
		return base64.RawStdEncoding.EncodeToString(sum[:])
	}

	// Zero-valued fields on the filter are left unset, exactly as an
	// unspecified criterion would be.
	filter := &auth.ExternalSession{ID: query.ID}
	if query.SessionID != "" {
		filter.SessionIDHash = digest(query.SessionID)
	}
	if query.NameID != "" {
		filter.NameIDHash = digest(query.NameID)
	}

	sessions := []*auth.ExternalSession{}
	find := func(sess *db.Session) error {
		return sess.Find(&sessions, filter)
	}
	if err := s.sqlStore.WithDbSession(ctx, find); err != nil {
		return nil, err
	}

	for i := range sessions {
		if err := s.decryptSecrets(sessions[i]); err != nil {
			return nil, err
		}
	}
	return sessions, nil
}
// Create persists extSession and writes the ID of the inserted row back to
// extSession.ID.
//
// The caller's struct is otherwise left untouched: a clone is what gets
// stored, with AccessToken, RefreshToken, IDToken, NameID and SessionID
// encrypted. For non-empty NameID and SessionID, the clone's NameIDHash and
// SessionIDHash are set to the unpadded base64 SHA-256 digest of the
// plaintext value.
func (s *store) Create(ctx context.Context, extSession *auth.ExternalSession) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.Create")
	defer span.End()

	// digest returns the unpadded base64 form of the SHA-256 sum of value.
	digest := func(value string) string {
		sum := sha256.Sum256([]byte(value))
		return base64.RawStdEncoding.EncodeToString(sum[:])
	}

	row := extSession.Clone()

	// Hashes are derived from the plaintext, so compute them independently
	// of the encrypted columns below.
	if extSession.NameID != "" {
		row.NameIDHash = digest(extSession.NameID)
	}
	if extSession.SessionID != "" {
		row.SessionIDHash = digest(extSession.SessionID)
	}

	// Encrypt each sensitive column in a fixed order; the first failure
	// aborts the whole operation.
	columns := []struct {
		dst   *string
		plain string
	}{
		{&row.AccessToken, extSession.AccessToken},
		{&row.RefreshToken, extSession.RefreshToken},
		{&row.IDToken, extSession.IDToken},
		{&row.NameID, extSession.NameID},
		{&row.SessionID, extSession.SessionID},
	}
	for _, col := range columns {
		encrypted, err := s.encryptAndEncode(col.plain)
		if err != nil {
			return err
		}
		*col.dst = encrypted
	}

	insert := func(sess *db.Session) error {
		_, err := sess.Insert(row)
		return err
	}
	if err := s.sqlStore.WithDbSession(ctx, insert); err != nil {
		return err
	}

	extSession.ID = row.ID
	return nil
}
// Update replaces the token material of the external session identified by ID
// with the values carried by cmd.Token.
//
// The access token, refresh token and (when present as a non-empty string
// under the "id_token" extra) the ID token are encrypted before being written.
// The update is restricted to the access_token, refresh_token, id_token and
// expires_at columns.
// NOTE(review): id_token is in the column list even when the token carries no
// id_token, so the stored value is presumably cleared in that case; confirm
// this is intended.
func (s *store) Update(ctx context.Context, ID int64, cmd *auth.UpdateExternalSessionCommand) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.Update")
	defer span.End()

	changes := &auth.ExternalSession{}

	accessToken, err := s.encryptAndEncode(cmd.Token.AccessToken)
	changes.AccessToken = accessToken
	if err != nil {
		return err
	}

	refreshToken, err := s.encryptAndEncode(cmd.Token.RefreshToken)
	changes.RefreshToken = refreshToken
	if err != nil {
		return err
	}

	if rawIDToken, ok := cmd.Token.Extra("id_token").(string); ok && rawIDToken != "" {
		encrypted, err := s.encryptAndEncode(rawIDToken)
		if err != nil {
			return err
		}
		changes.IDToken = encrypted
	}

	changes.ExpiresAt = cmd.Token.Expiry

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, err := sess.ID(ID).
			Cols("access_token", "refresh_token", "id_token", "expires_at").
			Update(changes)
		return err
	})
}
// Delete removes the external session identified by ID. The number of affected
// rows is discarded; only an error from the database layer is reported.
func (s *store) Delete(ctx context.Context, ID int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.Delete")
	defer span.End()

	target := &auth.ExternalSession{ID: ID}
	remove := func(sess *db.Session) error {
		_, err := sess.Delete(target)
		return err
	}
	return s.sqlStore.WithDbSession(ctx, remove)
}
// DeleteExternalSessionsByUserID removes the external sessions matching the
// given userID. The number of affected rows is discarded; only an error from
// the database layer is reported.
func (s *store) DeleteExternalSessionsByUserID(ctx context.Context, userID int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.DeleteExternalSessionsByUserID")
	defer span.End()

	target := &auth.ExternalSession{UserID: userID}
	remove := func(sess *db.Session) error {
		_, err := sess.Delete(target)
		return err
	}
	return s.sqlStore.WithDbSession(ctx, remove)
}
// BatchDeleteExternalSessionsByUserIDs removes the external sessions whose
// user_id is contained in userIDs.
//
// An empty or nil userIDs is an explicit no-op: the method returns nil without
// opening a database session, instead of issuing a DELETE with an empty IN
// list and relying on how the ORM renders it. The number of affected rows is
// discarded; only an error from the database layer is reported.
func (s *store) BatchDeleteExternalSessionsByUserIDs(ctx context.Context, userIDs []int64) error {
	ctx, span := s.tracer.Start(ctx, "externalsession.BatchDeleteExternalSessionsByUserIDs")
	defer span.End()

	// Nothing to delete; skip the round trip to the database.
	if len(userIDs) == 0 {
		return nil
	}

	return s.sqlStore.WithDbSession(ctx, func(sess *db.Session) error {
		_, err := sess.In("user_id", userIDs).Delete(&auth.ExternalSession{})
		return err
	})
}
// decryptSecrets decrypts, in place, the AccessToken, RefreshToken, IDToken,
// NameID and SessionID fields of extSession, in that order.
//
// Processing stops at the first failure and that error is returned; fields
// handled before the failure stay decrypted, and the failing field is
// overwritten with whatever decodeAndDecrypt returned alongside the error.
func (s *store) decryptSecrets(extSession *auth.ExternalSession) error {
	fields := []*string{
		&extSession.AccessToken,
		&extSession.RefreshToken,
		&extSession.IDToken,
		&extSession.NameID,
		&extSession.SessionID,
	}
	for _, field := range fields {
		plain, err := s.decodeAndDecrypt(*field)
		// Assign before checking err so the field is overwritten even on
		// failure.
		*field = plain
		if err != nil {
			return err
		}
	}
	return nil
}
// encryptAndEncode encrypts str with the secrets service (without scope) and
// returns the ciphertext in standard, padded base64.
//
// An empty str is returned unchanged without calling the secrets service.
// NOTE(review): encryption runs with context.Background(), not a
// caller-supplied context, so the caller's cancellation is not propagated.
func (s *store) encryptAndEncode(str string) (string, error) {
	if len(str) == 0 {
		return "", nil
	}

	ciphertext, err := s.secretsService.Encrypt(context.Background(), []byte(str), secrets.WithoutScope())
	if err == nil {
		return base64.StdEncoding.EncodeToString(ciphertext), nil
	}
	return "", err
}
// decodeAndDecrypt is the inverse of encryptAndEncode: it decodes str from
// standard, padded base64 and decrypts the result with the secrets service.
//
// It returns "" together with the error when decoding or decryption fails.
// NOTE(review): decryption runs with context.Background(), not a
// caller-supplied context, so the caller's cancellation is not propagated.
func (s *store) decodeAndDecrypt(str string) (string, error) {
	// An empty value is passed through untouched; the original author's note
	// says handing it to Decrypt would cause a segfault.
	if len(str) == 0 {
		return "", nil
	}

	ciphertext, err := base64.StdEncoding.DecodeString(str)
	if err != nil {
		return "", err
	}

	plaintext, err := s.secretsService.Decrypt(context.Background(), ciphertext)
	if err != nil {
		return "", err
	}
	return string(plaintext), nil
}