From 9d0d1084e39fdfd4a58d701b13ff755f4e5f36eb Mon Sep 17 00:00:00 2001 From: Robert Fratto Date: Thu, 20 Nov 2025 10:18:20 -0500 Subject: [PATCH] chore(engine): use github.com/grafana/dskit/dns for scheduler discovery (#19835) Signed-off-by: Robert Fratto --- .../internal/worker/scheduler_lookup.go | 193 ++++++++++++------ .../internal/worker/scheduler_lookup_test.go | 116 +++++------ 2 files changed, 179 insertions(+), 130 deletions(-) diff --git a/pkg/engine/internal/worker/scheduler_lookup.go b/pkg/engine/internal/worker/scheduler_lookup.go index ae0a6fad3e..179c429180 100644 --- a/pkg/engine/internal/worker/scheduler_lookup.go +++ b/pkg/engine/internal/worker/scheduler_lookup.go @@ -10,18 +10,14 @@ import ( "github.com/go-kit/log" "github.com/go-kit/log/level" + "github.com/grafana/dskit/dns" "github.com/grafana/dskit/grpcutil" ) type schedulerLookup struct { - // logger to log messages with. - logger log.Logger - - // Watcher emits events when schedulers are found. Each found address must - // be an IP address. - watcher grpcutil.Watcher - - closeOnce sync.Once + logger log.Logger + watcher *dnsWatcher + interval time.Duration } type handleScheduler func(ctx context.Context, addr net.Addr) @@ -31,29 +27,16 @@ func newSchedulerLookup(logger log.Logger, address string, lookupInterval time.D logger = log.NewNopLogger() } - resolver, err := grpcutil.NewDNSResolverWithFreq(lookupInterval, logger) - if err != nil { - return nil, fmt.Errorf("creating DNS resolver: %w", err) - } - - watcher, err := resolver.Resolve(address, "") - if err != nil { - return nil, fmt.Errorf("creating DNS watcher: %w", err) - } + provider := dns.NewProvider(logger, nil, dns.GolangResolverType) return &schedulerLookup{ - logger: logger, - watcher: watcher, + logger: logger, + watcher: newDNSWatcher(address, provider), + interval: lookupInterval, }, nil } func (l *schedulerLookup) Run(ctx context.Context, handlerFunc handleScheduler) error { - // Hook into context cancellation to close the watcher. We need to do this - // because the watcher doesn't accept a custom context when polling for - // changes. - stop := context.AfterFunc(ctx, l.closeWatcher) - defer stop() - var handlerWg sync.WaitGroup defer handlerWg.Wait() @@ -67,60 +50,69 @@ func (l *schedulerLookup) Run(ctx context.Context, handlerFunc handleScheduler) ctx, cancel := context.WithCancel(ctx) defer cancel() + // Initially set our timer with no delay so we process the first update + // immediately. We'll give it a real duration after each tick. + timer := time.NewTimer(0) + defer timer.Stop() + for { - updates, err := l.watcher.Next() - if err != nil && ctx.Err() == nil { - return fmt.Errorf("finding schedulers: %w", err) - } else if ctx.Err() != nil { - // The context was canceled, we can exit gracefully. + select { + case <-ctx.Done(): return nil - } - for _, update := range updates { - switch update.Op { - case grpcutil.Add: - if _, exist := handlers[update.Addr]; exist { - // Ignore duplicate handlers. - level.Warn(l.logger).Log("msg", "ignoring duplicate scheduler", "addr", update.Addr) - continue - } + case <-timer.C: + timer.Reset(l.interval) - addr, err := parseTCPAddr(update.Addr) - if err != nil { - level.Warn(l.logger).Log("msg", "failed to parse scheduler address", "addr", update.Addr, "err", err) - continue - } + updates, err := l.watcher.Poll(ctx) + if err != nil && ctx.Err() == nil { + return fmt.Errorf("finding schedulers: %w", err) + } else if ctx.Err() != nil { + // The context was canceled, we can exit gracefully. 
+ return nil + } - var handler handlerContext - handler.Context, handler.Cancel = context.WithCancel(ctx) - handlers[update.Addr] = handler - - handlerWg.Add(1) - go func() { - defer handlerWg.Done() - handlerFunc(handler.Context, addr) - }() - - case grpcutil.Delete: - handler, exist := handlers[update.Addr] - if !exist { - level.Warn(l.logger).Log("msg", "ignoring unrecognized scheduler", "addr", update.Addr) - continue + for _, update := range updates { + switch update.Op { + case grpcutil.Add: + if _, exist := handlers[update.Addr]; exist { + // Ignore duplicate handlers. + level.Warn(l.logger).Log("msg", "ignoring duplicate scheduler", "addr", update.Addr) + continue + } + + addr, err := parseTCPAddr(update.Addr) + if err != nil { + level.Warn(l.logger).Log("msg", "failed to parse scheduler address", "addr", update.Addr, "err", err) + continue + } + + var handler handlerContext + handler.Context, handler.Cancel = context.WithCancel(ctx) + handlers[update.Addr] = handler + + handlerWg.Add(1) + go func() { + defer handlerWg.Done() + handlerFunc(handler.Context, addr) + }() + + case grpcutil.Delete: + handler, exist := handlers[update.Addr] + if !exist { + level.Warn(l.logger).Log("msg", "ignoring unrecognized scheduler", "addr", update.Addr) + continue + } + handler.Cancel() + delete(handlers, update.Addr) + + default: + level.Warn(l.logger).Log("msg", "unknown scheduler update operation", "op", update.Op) } - handler.Cancel() - delete(handlers, update.Addr) - - default: - level.Warn(l.logger).Log("msg", "unknown scheduler update operation", "op", update.Op) } } } } -func (l *schedulerLookup) closeWatcher() { - l.closeOnce.Do(func() { l.watcher.Close() }) -} - // parseTCPAddr parses a TCP address string into a [net.TCPAddr]. It doesn't do // any name resolution: the addr must be a numeric pair of IP and port. func parseTCPAddr(addr string) (*net.TCPAddr, error) { @@ -131,3 +123,68 @@ func parseTCPAddr(addr string) (*net.TCPAddr, error) { return net.TCPAddrFromAddrPort(ap), nil } + +type provider interface { + Resolve(ctx context.Context, addrs []string) error + Addresses() []string +} + +type dnsWatcher struct { + addr string + provider provider + + cached map[string]struct{} +} + +func newDNSWatcher(addr string, provider provider) *dnsWatcher { + return &dnsWatcher{ + addr: addr, + provider: provider, + + cached: make(map[string]struct{}), + } +} + +// Poll polls for changes in the DNS records. 
+func (w *dnsWatcher) Poll(ctx context.Context) ([]*grpcutil.Update, error) { + if err := w.provider.Resolve(ctx, []string{w.addr}); err != nil { + return nil, err + } + + actual := w.discovered() + + var updates []*grpcutil.Update + for addr := range actual { + if _, exists := w.cached[addr]; exists { + continue + } + + w.cached[addr] = struct{}{} + updates = append(updates, &grpcutil.Update{ + Addr: addr, + Op: grpcutil.Add, + }) + } + + for addr := range w.cached { + if _, exists := actual[addr]; !exists { + delete(w.cached, addr) + updates = append(updates, &grpcutil.Update{ + Addr: addr, + Op: grpcutil.Delete, + }) + } + } + + return updates, nil +} + +func (w *dnsWatcher) discovered() map[string]struct{} { + slice := w.provider.Addresses() + + res := make(map[string]struct{}, len(slice)) + for _, addr := range slice { + res[addr] = struct{}{} + } + return res +} diff --git a/pkg/engine/internal/worker/scheduler_lookup_test.go b/pkg/engine/internal/worker/scheduler_lookup_test.go index edd84a6a37..09e47c1bc4 100644 --- a/pkg/engine/internal/worker/scheduler_lookup_test.go +++ b/pkg/engine/internal/worker/scheduler_lookup_test.go @@ -2,90 +2,82 @@ package worker import ( "context" - "errors" - "fmt" "net" "sync" "testing" + "testing/synctest" "time" "github.com/go-kit/log" - "github.com/grafana/dskit/grpcutil" "github.com/stretchr/testify/require" "go.uber.org/atomic" ) func Test_schedulerLookup(t *testing.T) { - var wg sync.WaitGroup - defer wg.Wait() - - fw := &fakeWatcher{ - ctx: t.Context(), - ch: make(chan *grpcutil.Update, 1), - } + // NOTE(rfratto): synctest makes it possible to reliably test asynchronous + // code with time.Sleep. + synctest.Test(t, func(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + // Provide 10 addresses to start with. + addrs := []string{ + "127.0.0.1:8080", "127.0.0.2:8080", "127.0.0.3:8080", "127.0.0.4:8080", "127.0.0.5:8080", + "127.0.0.6:8080", "127.0.0.7:8080", "127.0.0.8:8080", "127.0.0.9:8080", "127.0.0.10:8080", + } - // Manually create a schedulerLookup so we can hook in a custom - // implementation of [grpcutil.Watcher]. - disc := &schedulerLookup{ - logger: log.NewNopLogger(), - watcher: fw, - } + fr := &fakeProvider{ + resolveFunc: func(_ context.Context, _ []string) ([]string, error) { return addrs, nil }, + } - var handlers atomic.Int64 + // Manually create a schedulerLookup so we can hook in a custom + // implementation of [grpcutil.Watcher]. + disc := &schedulerLookup{ + logger: log.NewNopLogger(), + watcher: newDNSWatcher("example.com", fr), + interval: 1 * time.Minute, + } - lookupContext, lookupCancel := context.WithCancel(t.Context()) - defer lookupCancel() + var handlers atomic.Int64 - wg.Add(1) - go func() { - // Decrement the wait group once Run exits. Run won't exit until all - // handlers have terminated, so this validates that logic. - defer wg.Done() + lookupContext, lookupCancel := context.WithCancel(t.Context()) + defer lookupCancel() - _ = disc.Run(lookupContext, func(ctx context.Context, _ net.Addr) { - context.AfterFunc(ctx, func() { handlers.Dec() }) - handlers.Inc() + wg.Go(func() { + _ = disc.Run(lookupContext, func(ctx context.Context, _ net.Addr) { + context.AfterFunc(ctx, func() { handlers.Dec() }) + handlers.Inc() + }) }) - }() - - // Emit 10 schedulers, then wait for there to be one handler per - // scheduler. 
- for i := range 10 { - addr := fmt.Sprintf("127.0.0.%d:8080", i+1) - fw.ch <- &grpcutil.Update{Op: grpcutil.Add, Addr: addr} - } - - require.Eventually(t, func() bool { - return handlers.Load() == 10 - }, time.Minute, time.Millisecond*10, "should have 10 running handlers, ended with %d", handlers.Load()) - - // Delete all the schedulers, then wait for all handlers to terminate (by - // context). - for i := range 10 { - addr := fmt.Sprintf("127.0.0.%d:8080", i+1) - fw.ch <- &grpcutil.Update{Op: grpcutil.Delete, Addr: addr} - } - require.Eventually(t, func() bool { - return handlers.Load() == 0 - }, time.Minute, time.Millisecond*10, "should have no handlers running, ended with %d", handlers.Load()) + // There should immediately be running handlers without needing to wait + // for the discovery interval. + synctest.Wait() + require.Equal(t, int64(10), handlers.Load(), "should have 10 running handlers") + + // Remove all the addresses from discovery; after the next interval, all + // handlers should be removed. + addrs = addrs[:0] + time.Sleep(disc.interval + time.Second) + require.Equal(t, int64(0), handlers.Load(), "should have no running handlers") + }) } -type fakeWatcher struct { - ctx context.Context - ch chan *grpcutil.Update +type fakeProvider struct { + resolveFunc func(ctx context.Context, addrs []string) ([]string, error) + + cached []string } -func (fw fakeWatcher) Next() ([]*grpcutil.Update, error) { - select { - case <-fw.ctx.Done(): - return nil, fw.ctx.Err() - case update, ok := <-fw.ch: - if !ok { - return nil, errors.New("closed") - } - return []*grpcutil.Update{update}, nil +func (fp *fakeProvider) Resolve(ctx context.Context, addrs []string) error { + resolved, err := fp.resolveFunc(ctx, addrs) + if err != nil { + return err } + fp.cached = resolved + return nil } -func (fw fakeWatcher) Close() { close(fw.ch) } +func (fp *fakeProvider) Addresses() []string { + return fp.cached +}
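
For context, here is a minimal standalone sketch of the dskit DNS polling pattern this patch adopts. It only uses the calls that the new provider interface above relies on (dns.NewProvider, Provider.Resolve, Provider.Addresses); the service name, port, and poll interval are illustrative assumptions, not values taken from this change.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/go-kit/log"
	"github.com/grafana/dskit/dns"
)

func main() {
	logger := log.NewNopLogger()

	// A nil prometheus.Registerer simply means provider metrics are not
	// registered in this sketch.
	provider := dns.NewProvider(logger, nil, dns.GolangResolverType)

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()

	for {
		// Resolve refreshes the provider's view of the address; the "dns+"
		// prefix asks dskit to look up A/AAAA records for the host. The
		// hostname below is hypothetical.
		if err := provider.Resolve(ctx, []string{"dns+query-scheduler.example.svc.cluster.local:9095"}); err != nil {
			fmt.Println("resolve failed:", err)
			return
		}

		// Addresses returns the most recently resolved ip:port pairs.
		// Diffing successive snapshots against a cache is what turns them
		// into add/delete updates, as dnsWatcher.Poll does in the patch.
		fmt.Println("discovered schedulers:", provider.Addresses())

		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}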