Chore: Add snowflake xorm tag (#104300)

* add tag handler for snowflake ids

* add snowflake generator

* fill snowflake id back to the bean

* table driven test, mockable snowflake generator

* use math/rand/v2

* snowflake without time.sleep

* more explicit bitwise modulo

* rename snowflake to randomid
pull/105039/head
Serge Zaitsev 3 months ago committed by GitHub
parent 1877b671cb
commit bf918976b2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 5
      pkg/util/xorm/core/core.go
  2. 1
      pkg/util/xorm/engine.go
  3. 70
      pkg/util/xorm/session_insert.go
  4. 7
      pkg/util/xorm/tag.go
  5. 2
      pkg/util/xorm/xorm.go
  6. 40
      pkg/util/xorm/xorm_test.go

@ -34,6 +34,7 @@ type Column struct {
Indexes map[string]int
IsPrimaryKey bool
IsAutoIncrement bool
IsRandomID bool
IsCreated bool
IsUpdated bool
IsDeleted bool
@ -1580,6 +1581,7 @@ type Table struct {
Indexes map[string]*Index
PrimaryKeys []string
AutoIncrement string
RandomID string
Created map[string]bool
Updated string
Deleted string
@ -1695,6 +1697,9 @@ func (table *Table) AddColumn(col *Column) {
if col.IsAutoIncrement {
table.AutoIncrement = col.Name
}
if col.IsRandomID {
table.RandomID = col.Name
}
if col.IsCreated {
table.Created[col.Name] = true
}

@ -44,6 +44,7 @@ type Engine struct {
defaultContext context.Context
sequenceGenerator SequenceGenerator // If not nil, this generator is used to generate auto-increment values for inserts.
randomIDGen func() int64
}
// CondDeleted returns the conditions whether a record is soft deleted.

@ -8,9 +8,12 @@ import (
"errors"
"fmt"
"reflect"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/grafana/grafana/pkg/util/xorm/core"
"xorm.io/builder"
@ -306,13 +309,7 @@ func (session *Session) innerInsert(bean any) (int64, error) {
// If engine has a sequence number generator, use it to produce values for auto-increment columns.
if len(table.AutoIncrement) > 0 && session.engine.sequenceGenerator != nil {
var found bool
for _, col := range colNames {
if col == table.AutoIncrement {
found = true
break
}
}
found := slices.Contains(colNames, table.AutoIncrement)
if !found {
seq, err := session.engine.sequenceGenerator.Next(session.ctx, table.Name, table.AutoIncrement)
if err != nil {
@ -322,6 +319,26 @@ func (session *Session) innerInsert(bean any) (int64, error) {
colNames = append(colNames, table.AutoIncrement)
args = append(args, seq)
}
} else if len(table.RandomID) > 0 {
found := slices.Contains(colNames, table.RandomID)
if !found {
id := session.engine.randomIDGen()
colNames = append(colNames, table.RandomID)
args = append(args, id)
// Set random ID back to the bean.
col := table.GetColumn(table.RandomID)
if col == nil {
return 0, fmt.Errorf("column %s not found in table %s", table.RandomID, table.Name)
}
idValue, err := col.ValueOf(bean)
if err != nil {
session.engine.logger.Error(err)
}
if idValue == nil || !idValue.IsValid() || !idValue.CanSet() {
return 0, fmt.Errorf("failed to set snowflake ID to bean: %v", err)
}
idValue.Set(int64ToIntValue(id, idValue.Type()))
}
}
exprs := session.statement.exprColumns
@ -587,7 +604,7 @@ func (session *Session) genInsertColumns(bean any) ([]string, []any, error) {
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement {
if col.IsAutoIncrement || col.IsRandomID {
switch fieldValue.Type().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int, reflect.Int64:
if fieldValue.Int() == 0 {
@ -780,3 +797,40 @@ func (session *Session) insertMap(columns []string, args []any) (int64, error) {
}
return affected, nil
}
type snowflake struct {
mu sync.Mutex
nodeID int64
sequence int64
lastTime int64
epoch time.Time
}
// newSnowflake creates a new instance with a random node ID (0-1023)
// It forcefully converts epoch time (in milliseconds) to monotonic time
func newSnowflake(nodeID int64) *snowflake {
const snowflakeEpoch = 1288834974657 // 2010-11-04 01:42:54.657 UTC
epoch := time.Unix(snowflakeEpoch/1000, (snowflakeEpoch%1000)*1000000)
now := time.Now()
return &snowflake{nodeID: nodeID & 0x3ff, epoch: now.Add(epoch.Sub(now))}
}
func (s *snowflake) Generate() int64 {
s.mu.Lock()
defer s.mu.Unlock()
currentTime := time.Since(s.epoch).Milliseconds()
if currentTime == s.lastTime {
s.sequence = (s.sequence + 1) & 0xfff
if s.sequence == 0 {
// wait for next millisecond, we are not using time.Sleep() here due to its low resolution (often >4ms)
for currentTime <= s.lastTime {
currentTime = time.Since(s.epoch).Milliseconds()
}
}
} else {
s.sequence = 0
}
s.lastTime = currentTime
id := (currentTime << 22) | (s.nodeID << 12) | s.sequence
return id
}

@ -38,6 +38,7 @@ var (
"NULL": NULLTagHandler,
"NOT": IgnoreTagHandler,
"AUTOINCR": AutoIncrTagHandler,
"RANDOMID": RandomIDTagHandler,
"DEFAULT": DefaultTagHandler,
"CREATED": CreatedTagHandler,
"UPDATED": UpdatedTagHandler,
@ -88,6 +89,12 @@ func AutoIncrTagHandler(ctx *tagContext) error {
return nil
}
// RandomIDTagHandler describes snowflake id tag handler
func RandomIDTagHandler(ctx *tagContext) error {
ctx.col.IsRandomID = true
return nil
}
// DefaultTagHandler describes default tag handler
func DefaultTagHandler(ctx *tagContext) error {
if len(ctx.params) > 0 {

@ -11,6 +11,7 @@ import (
"context"
"database/sql"
"fmt"
"math/rand/v2"
"os"
"reflect"
"runtime"
@ -95,6 +96,7 @@ func NewEngine(driverName string, dataSourceName string) (*Engine, error) {
tagHandlers: defaultTagHandlers,
defaultContext: context.Background(),
timestampFormat: "2006-01-02 15:04:05",
randomIDGen: newSnowflake(rand.Int64N(1024)).Generate,
}
switch uri.DbType {

@ -48,3 +48,43 @@ type TestStruct struct {
Comment string
Json json.RawMessage
}
func TestRandomID(t *testing.T) {
type RandomIDRecord struct {
ID int64 `xorm:"'id' pk randomid"`
Comment string
}
eng, err := NewEngine("sqlite3", ":memory:")
require.NoError(t, err)
require.NoError(t, eng.Sync(new(RandomIDRecord)))
// Test sequence of different snowflake values
testCases := []struct {
name string
id int64
comment string
}{
{"first insert", 42, "first comment"},
{"second insert", 123, "second comment"},
{"third insert", 1337, "third comment"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
eng.randomIDGen = func() int64 { return tc.id }
obj := &RandomIDRecord{Comment: tc.comment}
_, err := eng.Insert(obj)
require.NoError(t, err)
require.Equal(t, tc.id, obj.ID, "ID should match current snowflake value")
// Verify database entry
var retrieved RandomIDRecord
has, err := eng.ID(tc.id).Get(&retrieved)
require.NoError(t, err)
require.True(t, has)
require.Equal(t, tc.comment, retrieved.Comment)
})
}
}

Loading…
Cancel
Save