Fix pubsub pull target (#8281)

pull/8283/head
Travis Patterson 2 years ago committed by GitHub
parent 58e29de988
commit 71979f0f42
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 85
      clients/pkg/promtail/targets/gcplog/pull_target.go
  2. 260
      clients/pkg/promtail/targets/gcplog/pull_target_test.go
  3. 61
      vendor/cloud.google.com/go/internal/testutil/cmp.go
  4. 174
      vendor/cloud.google.com/go/internal/testutil/context.go
  5. 187
      vendor/cloud.google.com/go/internal/testutil/headers_enforcer.go
  6. 44
      vendor/cloud.google.com/go/internal/testutil/rand.go
  7. 116
      vendor/cloud.google.com/go/internal/testutil/retry.go
  8. 135
      vendor/cloud.google.com/go/internal/testutil/server.go
  9. 78
      vendor/cloud.google.com/go/internal/testutil/trace.go
  10. 1493
      vendor/cloud.google.com/go/pubsub/pstest/fake.go
  11. 32
      vendor/google.golang.org/api/impersonate/doc.go
  12. 129
      vendor/google.golang.org/api/impersonate/idtoken.go
  13. 184
      vendor/google.golang.org/api/impersonate/impersonate.go
  14. 169
      vendor/google.golang.org/api/impersonate/user.go
  15. 3
      vendor/modules.txt

@ -1,21 +1,35 @@
package gcplog
import (
"context"
"sync"
"cloud.google.com/go/pubsub"
"context"
"fmt"
"github.com/go-kit/log"
"github.com/go-kit/log/level"
"github.com/grafana/dskit/backoff"
"github.com/prometheus/common/model"
"github.com/prometheus/prometheus/model/relabel"
"google.golang.org/api/option"
"io"
"sync"
"time"
"github.com/grafana/loki/clients/pkg/promtail/api"
"github.com/grafana/loki/clients/pkg/promtail/scrapeconfig"
"github.com/grafana/loki/clients/pkg/promtail/targets/target"
)
var defaultBackoff = backoff.Config{
MinBackoff: 1 * time.Second,
MaxBackoff: 10 * time.Second,
MaxRetries: 5,
}
// pubsubSubscription allows us to mock pubsub for testing
type pubsubSubscription interface {
Receive(ctx context.Context, f func(context.Context, *pubsub.Message)) error
}
// pullTarget represents the target specific to GCP project, with a pull subscription type.
// It collects logs from GCP and push it to Loki.
// nolint:revive
@ -28,12 +42,14 @@ type pullTarget struct {
jobName string
// lifecycle management
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
backoff *backoff.Backoff
// pubsub
ps *pubsub.Client
ps io.Closer
sub pubsubSubscription
msgs chan *pubsub.Message
}
@ -53,9 +69,9 @@ func newPullTarget(
clientOptions ...option.ClientOption,
) (*pullTarget, error) {
ctx, cancel := context.WithCancel(context.Background())
ps, err := pubsub.NewClient(ctx, config.ProjectID, clientOptions...)
if err != nil {
cancel()
return nil, err
}
@ -69,6 +85,8 @@ func newPullTarget(
ctx: ctx,
cancel: cancel,
ps: ps,
sub: ps.SubscriptionInProject(config.Subscription, config.ProjectID),
backoff: backoff.New(ctx, defaultBackoff),
msgs: make(chan *pubsub.Message),
}
@ -83,28 +101,15 @@ func (t *pullTarget) run() error {
t.wg.Add(1)
defer t.wg.Done()
send := t.handler.Chan()
sub := t.ps.SubscriptionInProject(t.config.Subscription, t.config.ProjectID)
go func() {
// NOTE(kavi): `cancel` the context as exiting from this goroutine should stop main `run` loop
// It makesense as no more messages will be received.
defer t.cancel()
err := sub.Receive(t.ctx, func(ctx context.Context, m *pubsub.Message) {
t.msgs <- m
})
if err != nil {
level.Error(t.logger).Log("msg", "failed to receive pubsub messages", "error", err)
t.metrics.gcplogErrors.WithLabelValues(t.config.ProjectID).Inc()
t.metrics.gcplogTargetLastSuccessScrape.WithLabelValues(t.config.ProjectID, t.config.Subscription).SetToCurrentTime()
}
}()
subscriptionErr := make(chan error)
go t.consumeSubscription(subscriptionErr)
for {
select {
case <-t.ctx.Done():
return t.ctx.Err()
case e := <-subscriptionErr:
return e
case m := <-t.msgs:
entry, err := parseGCPLogsEntry(m.Data, t.config.Labels, nil, t.config.UseIncomingTimestamp, t.relabelConfig)
if err != nil {
@ -112,13 +117,41 @@ func (t *pullTarget) run() error {
m.Ack()
break
}
send <- entry
t.handler.Chan() <- entry
m.Ack() // Ack only after log is sent.
t.metrics.gcplogEntries.WithLabelValues(t.config.ProjectID).Inc()
}
}
}
func (t *pullTarget) consumeSubscription(subscriptionErr chan error) {
// NOTE(kavi): `cancel` the context as exiting from this goroutine should stop main `run` loop
// It makesense as no more messages will be received.
defer t.cancel()
var lastError error
for t.backoff.Ongoing() {
lastError = t.sub.Receive(t.ctx, func(ctx context.Context, m *pubsub.Message) {
t.msgs <- m
// When the subscription works properly, it doesn't return
// Reset relevant state here
lastError = nil
t.backoff.Reset()
})
if lastError != nil {
level.Error(t.logger).Log("msg", "failed to receive pubsub messages", "error", lastError)
t.metrics.gcplogErrors.WithLabelValues(t.config.ProjectID).Inc()
t.metrics.gcplogTargetLastSuccessScrape.WithLabelValues(t.config.ProjectID, t.config.Subscription).SetToCurrentTime()
t.backoff.Wait()
}
}
if t.ctx.Err() == nil && t.backoff.Err() != nil {
subscriptionErr <- fmt.Errorf("%w: %s", t.backoff.Err(), lastError.Error())
}
}
func (t *pullTarget) Type() target.TargetType {
return target.GcplogTargetType
}

@ -2,173 +2,174 @@ package gcplog
import (
"context"
"sync"
"github.com/grafana/dskit/backoff"
"github.com/pkg/errors"
"io"
"testing"
"time"
"cloud.google.com/go/pubsub"
"cloud.google.com/go/pubsub/pstest"
"github.com/go-kit/log"
"github.com/grafana/loki/clients/pkg/promtail/client/fake"
"github.com/grafana/loki/clients/pkg/promtail/scrapeconfig"
"github.com/grafana/loki/clients/pkg/promtail/targets/target"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/common/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/grafana/loki/clients/pkg/promtail/api"
"github.com/grafana/loki/clients/pkg/promtail/client/fake"
"github.com/grafana/loki/clients/pkg/promtail/scrapeconfig"
"github.com/grafana/loki/clients/pkg/promtail/targets/target"
)
func TestPullTarget_Run(t *testing.T) {
// Goal: Check message written to pubsub topic is received by the target.
ctx := context.Background()
tt, apiclient, pubsubClient, teardown := testPullTarget(ctx, t)
defer teardown()
// seed pubsub
tp, err := pubsubClient.CreateTopic(ctx, topic)
require.NoError(t, err)
defer tp.Stop()
_, err = pubsubClient.CreateSubscription(ctx, subscription, pubsub.SubscriptionConfig{
Topic: tp,
})
require.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
tt.run() //nolint:errcheck
}()
func TestPullTarget_RunStop(t *testing.T) {
t.Run("it sends messages to the promclient and stopps when Stop() is called", func(t *testing.T) {
tc := testPullTarget(t)
publishMessage(ctx, t, tp)
// Wait till message is received by the run loop.
// NOTE(kavi): sleep is not ideal. but not other way to confirm if api.Handler received messages
time.Sleep(500 * time.Millisecond)
runErr := make(chan error)
go func() {
runErr <- tc.target.run()
}()
err = tt.Stop()
require.NoError(t, err)
tc.sub.messages <- &pubsub.Message{Data: []byte(gcpLogEntry)}
require.Eventually(t, func() bool {
return len(tc.promClient.Received()) > 0
}, time.Second, 50*time.Millisecond)
// wait till `run` stops.
wg.Wait()
// Sleep one more time before reading from api.Received.
time.Sleep(500 * time.Millisecond)
assert.Equal(t, 1, len(apiclient.Received()))
}
func TestPullTarget_Stop(t *testing.T) {
// Goal: To test that `run()` stops when you invoke `target.Stop()`
errs := make(chan error, 1)
require.NoError(t, tc.target.Stop())
require.EqualError(t, <-runErr, "context canceled")
})
ctx := context.Background()
tt, _, _, teardown := testPullTarget(ctx, t)
defer teardown()
t.Run("it retries when there is an error", func(t *testing.T) {
tc := testPullTarget(t)
runErr := make(chan error)
go func() {
runErr <- tc.target.run()
}()
tc.sub.errors <- errors.New("something bad")
tc.sub.messages <- &pubsub.Message{Data: []byte(gcpLogEntry)}
require.Eventually(t, func() bool {
return len(tc.promClient.Received()) > 0
}, time.Second, 50*time.Millisecond)
require.NoError(t, tc.target.Stop())
require.Eventually(t, func() bool {
select {
case e := <-runErr:
return e.Error() == "context canceled"
default:
return false
}
}, time.Second, 50*time.Millisecond)
})
var wg sync.WaitGroup
t.Run("it gives up after MaxRetries of errors", func(t *testing.T) {
tc := testPullTarget(t)
runErr := make(chan error)
go func() {
runErr <- tc.target.run()
}()
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
require.NoError(t, tc.target.Stop())
require.Eventually(t, func() bool {
select {
case e := <-runErr:
return e.Error() == "terminated after 5 retries: something bad"
default:
return false
}
}, time.Second, 50*time.Millisecond)
})
wg.Add(1)
go func() {
defer wg.Done()
errs <- tt.run()
}()
t.Run("a successful message resets retrues", func(t *testing.T) {
tc := testPullTarget(t)
// invoke stop
_ = tt.Stop()
runErr := make(chan error)
go func() {
runErr <- tc.target.run()
}()
// wait till run returns
wg.Wait()
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.messages <- &pubsub.Message{Data: []byte(gcpLogEntry)}
tc.sub.errors <- errors.New("something bad")
tc.sub.errors <- errors.New("something bad")
tc.sub.messages <- &pubsub.Message{Data: []byte(gcpLogEntry)}
// wouldn't block as 1 error is buffered into the channel.
err := <-errs
require.Eventually(t, func() bool {
return len(tc.promClient.Received()) > 1
}, time.Second, 50*time.Millisecond)
// returned error should be cancelled context error
assert.Equal(t, tt.ctx.Err(), err)
require.NoError(t, tc.target.Stop())
})
}
func TestPullTarget_Type(t *testing.T) {
ctx := context.Background()
tt, _, _, teardown := testPullTarget(ctx, t)
defer teardown()
tc := testPullTarget(t)
assert.Equal(t, target.TargetType("Gcplog"), tt.Type())
assert.Equal(t, target.TargetType("Gcplog"), tc.target.Type())
}
func TestPullTarget_Ready(t *testing.T) {
ctx := context.Background()
tt, _, _, teardown := testPullTarget(ctx, t)
defer teardown()
tc := testPullTarget(t)
assert.Equal(t, true, tt.Ready())
assert.Equal(t, true, tc.target.Ready())
}
func TestPullTarget_Labels(t *testing.T) {
ctx := context.Background()
tt, _, _, teardown := testPullTarget(ctx, t)
defer teardown()
tc := testPullTarget(t)
assert.Equal(t, model.LabelSet{"job": "test-gcplogtarget"}, tt.Labels())
assert.Equal(t, model.LabelSet{"job": "test-gcplogtarget"}, tc.target.Labels())
}
func testPullTarget(ctx context.Context, t *testing.T) (*pullTarget, *fake.Client, *pubsub.Client, func()) {
t.Helper()
ctx, cancel := context.WithCancel(ctx)
mockSvr := pstest.NewServer()
conn, err := grpc.Dial(mockSvr.Addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
mockpubsubClient, err := pubsub.NewClient(ctx, testConfig.ProjectID, option.WithGRPCConn(conn))
require.NoError(t, err)
type testContext struct {
target *pullTarget
promClient *fake.Client
sub *fakeSubscription
}
fakeClient := fake.New(func() {})
func testPullTarget(t *testing.T) *testContext {
t.Helper()
var handler api.EntryHandler = fakeClient
ctx, cancel := context.WithCancel(context.Background())
sub := newFakeSubscription()
promClient := fake.New(func() {})
target := &pullTarget{
metrics: NewMetrics(prometheus.NewRegistry()),
logger: log.NewNopLogger(),
handler: handler,
handler: promClient,
relabelConfig: nil,
config: testConfig,
jobName: t.Name() + "job-test-gcplogtarget",
ctx: ctx,
cancel: cancel,
ps: mockpubsubClient,
config: testConfig,
jobName: t.Name() + "job-test-gcplogtarget",
ps: io.NopCloser(nil),
sub: sub,
msgs: make(chan *pubsub.Message),
backoff: backoff.New(ctx, testBackoff),
}
// cleanup
return target, fakeClient, mockpubsubClient, func() {
cancel()
conn.Close()
mockSvr.Close()
mockpubsubClient.Close()
return &testContext{
target: target,
promClient: promClient,
sub: sub,
}
}
func publishMessage(ctx context.Context, t *testing.T, topic *pubsub.Topic) {
t.Helper()
res := topic.Publish(ctx, &pubsub.Message{Data: []byte(gcpLogEntry)})
_, err := res.Get(ctx) // wait till message is actully published
require.NoError(t, err)
}
const (
project = "test-project"
topic = "test-topic"
subscription = "test-subscription"
gcpLogEntry = `
gcpLogEntry = `
{
"insertId": "ajv4d1f1ch8dr",
"logName": "projects/grafanalabs-dev/logs/cloudaudit.googleapis.com%2Fdata_access",
@ -245,3 +246,32 @@ var testConfig = &scrapeconfig.GcplogTargetConfig{
},
SubscriptionType: "pull",
}
func newFakeSubscription() *fakeSubscription {
return &fakeSubscription{
messages: make(chan *pubsub.Message),
errors: make(chan error),
}
}
type fakeSubscription struct {
messages chan *pubsub.Message
errors chan error
}
func (s *fakeSubscription) Receive(ctx context.Context, f func(context.Context, *pubsub.Message)) error {
for {
select {
case m := <-s.messages:
f(ctx, m)
case e := <-s.errors:
return e
}
}
}
var testBackoff = backoff.Config{
MinBackoff: 1 * time.Millisecond,
MaxBackoff: 10 * time.Millisecond,
MaxRetries: 5,
}

@ -1,61 +0,0 @@
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"math"
"math/big"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
)
var (
alwaysEqual = cmp.Comparer(func(_, _ interface{}) bool { return true })
defaultCmpOptions = []cmp.Option{
// Use proto.Equal for protobufs
cmp.Comparer(proto.Equal),
// Use big.Rat.Cmp for big.Rats
cmp.Comparer(func(x, y *big.Rat) bool {
if x == nil || y == nil {
return x == y
}
return x.Cmp(y) == 0
}),
// NaNs compare equal
cmp.FilterValues(func(x, y float64) bool {
return math.IsNaN(x) && math.IsNaN(y)
}, alwaysEqual),
cmp.FilterValues(func(x, y float32) bool {
return math.IsNaN(float64(x)) && math.IsNaN(float64(y))
}, alwaysEqual),
}
)
// Equal tests two values for equality.
func Equal(x, y interface{}, opts ...cmp.Option) bool {
// Put default options at the end. Order doesn't matter.
opts = append(opts[:len(opts):len(opts)], defaultCmpOptions...)
return cmp.Equal(x, y, opts...)
}
// Diff reports the differences between two values.
// Diff(x, y) == "" iff Equal(x, y).
func Diff(x, y interface{}, opts ...cmp.Option) string {
// Put default options at the end. Order doesn't matter.
opts = append(opts[:len(opts):len(opts)], defaultCmpOptions...)
return cmp.Diff(x, y, opts...)
}

@ -1,174 +0,0 @@
// Copyright 2014 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package testutil contains helper functions for writing tests.
package testutil
import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"golang.org/x/oauth2/jwt"
"google.golang.org/api/impersonate"
)
const (
envProjID = "GCLOUD_TESTS_GOLANG_PROJECT_ID"
envPrivateKey = "GCLOUD_TESTS_GOLANG_KEY"
envImpersonate = "GCLOUD_TESTS_IMPERSONATE_CREDENTIALS"
)
// ProjID returns the project ID to use in integration tests, or the empty
// string if none is configured.
func ProjID() string {
return os.Getenv(envProjID)
}
// Credentials returns the credentials to use in integration tests, or nil if
// none is configured. It uses the standard environment variable for tests in
// this repo.
func Credentials(ctx context.Context, scopes ...string) *google.Credentials {
return CredentialsEnv(ctx, envPrivateKey, scopes...)
}
// CredentialsEnv returns the credentials to use in integration tests, or nil
// if none is configured. If the environment variable is unset, CredentialsEnv
// will try to find 'Application Default Credentials'. Else, CredentialsEnv
// will return nil. CredentialsEnv will log.Fatal if the token source is
// specified but missing or invalid.
func CredentialsEnv(ctx context.Context, envVar string, scopes ...string) *google.Credentials {
if impKey := os.Getenv(envImpersonate); impKey == "true" {
return &google.Credentials{
TokenSource: impersonatedTokenSource(ctx, scopes),
ProjectID: "dulcet-port-762",
}
}
key := os.Getenv(envVar)
if key == "" { // Try for application default credentials.
creds, err := google.FindDefaultCredentials(ctx, scopes...)
if err != nil {
log.Println("No 'Application Default Credentials' found.")
return nil
}
return creds
}
data, err := ioutil.ReadFile(key)
if err != nil {
log.Fatal(err)
}
creds, err := google.CredentialsFromJSON(ctx, data, scopes...)
if err != nil {
log.Fatal(err)
}
return creds
}
// TokenSource returns the OAuth2 token source to use in integration tests,
// or nil if none is configured. It uses the standard environment variable
// for tests in this repo.
func TokenSource(ctx context.Context, scopes ...string) oauth2.TokenSource {
return TokenSourceEnv(ctx, envPrivateKey, scopes...)
}
// TokenSourceEnv returns the OAuth2 token source to use in integration tests. or nil
// if none is configured. It tries to get credentials from the filename in the
// environment variable envVar. If the environment variable is unset, TokenSourceEnv
// will try to find 'Application Default Credentials'. Else, TokenSourceEnv will
// return nil. TokenSourceEnv will log.Fatal if the token source is specified but
// missing or invalid.
func TokenSourceEnv(ctx context.Context, envVar string, scopes ...string) oauth2.TokenSource {
if impKey := os.Getenv(envImpersonate); impKey == "true" {
return impersonatedTokenSource(ctx, scopes)
}
key := os.Getenv(envVar)
if key == "" { // Try for application default credentials.
ts, err := google.DefaultTokenSource(ctx, scopes...)
if err != nil {
log.Println("No 'Application Default Credentials' found.")
return nil
}
return ts
}
conf, err := jwtConfigFromFile(key, scopes)
if err != nil {
log.Fatal(err)
}
return conf.TokenSource(ctx)
}
func impersonatedTokenSource(ctx context.Context, scopes []string) oauth2.TokenSource {
ts, err := impersonate.CredentialsTokenSource(ctx, impersonate.CredentialsConfig{
TargetPrincipal: "kokoro@dulcet-port-762.iam.gserviceaccount.com",
Scopes: scopes,
})
if err != nil {
log.Fatalf("Unable to impersonate credentials, exiting: %v", err)
}
return ts
}
// JWTConfig reads the JSON private key file whose name is in the default
// environment variable, and returns the jwt.Config it contains. It ignores
// scopes.
// If the environment variable is empty, it returns (nil, nil).
func JWTConfig() (*jwt.Config, error) {
return jwtConfigFromFile(os.Getenv(envPrivateKey), nil)
}
// jwtConfigFromFile reads the given JSON private key file, and returns the
// jwt.Config it contains.
// If the filename is empty, it returns (nil, nil).
func jwtConfigFromFile(filename string, scopes []string) (*jwt.Config, error) {
if filename == "" {
return nil, nil
}
jsonKey, err := ioutil.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("cannot read the JSON key file, err: %v", err)
}
conf, err := google.JWTConfigFromJSON(jsonKey, scopes...)
if err != nil {
return nil, fmt.Errorf("google.JWTConfigFromJSON: %v", err)
}
return conf, nil
}
// CanReplay reports whether an integration test can be run in replay mode.
// The replay file must exist, and the GCLOUD_TESTS_GOLANG_ENABLE_REPLAY
// environment variable must be non-empty.
func CanReplay(replayFilename string) bool {
if os.Getenv("GCLOUD_TESTS_GOLANG_ENABLE_REPLAY") == "" {
return false
}
_, err := os.Stat(replayFilename)
return err == nil
}
// ErroringTokenSource is a token source for testing purposes,
// to always return a non-nil error to its caller. It is useful
// when testing error responses with bad oauth2 credentials.
type ErroringTokenSource struct{}
// Token implements oauth2.TokenSource, returning a nil oauth2.Token and a non-nil error.
func (fts ErroringTokenSource) Token() (*oauth2.Token, error) {
return nil, errors.New("intentional error")
}

@ -1,187 +0,0 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"bytes"
"context"
"errors"
"fmt"
"log"
"os"
"strings"
"google.golang.org/api/option"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
)
// HeaderChecker defines header checking and validation rules for any outgoing metadata.
type HeaderChecker struct {
// Key is the header name to be checked against e.g. "x-goog-api-client".
Key string
// ValuesValidator validates the header values retrieved from mapping against
// Key in the Headers.
ValuesValidator func(values ...string) error
}
// HeadersEnforcer asserts that outgoing RPC headers
// are present and match expectations. If the expected headers
// are not present or don't match expectations, it'll invoke OnFailure
// with the validation error, or instead log.Fatal if OnFailure is nil.
//
// It expects that every declared key will be present in the outgoing
// RPC header and each value will be validated by the validation function.
type HeadersEnforcer struct {
// Checkers maps header keys that are expected to be sent in the metadata
// of outgoing gRPC requests, against the values passed into the custom
// validation functions.
//
// If Checkers is nil or empty, only the default header "x-goog-api-client"
// will be checked for.
// Otherwise, if you supply Matchers, those keys and their respective
// validation functions will be checked.
Checkers []*HeaderChecker
// OnFailure is the function that will be invoked after all validation
// failures have been composed. If OnFailure is nil, log.Fatal will be
// invoked instead.
OnFailure func(fmt_ string, args ...interface{})
}
// StreamInterceptors returns a list of StreamClientInterceptor functions which
// enforce the presence and validity of expected headers during streaming RPCs.
//
// For client implementations which provide their own StreamClientInterceptor(s)
// these interceptors should be specified as the final elements to
// WithChainStreamInterceptor.
//
// Alternatively, users may apply gPRC options produced from DialOptions to
// apply all applicable gRPC interceptors.
func (h *HeadersEnforcer) StreamInterceptors() []grpc.StreamClientInterceptor {
return []grpc.StreamClientInterceptor{h.interceptStream}
}
// UnaryInterceptors returns a list of UnaryClientInterceptor functions which
// enforce the presence and validity of expected headers during unary RPCs.
//
// For client implementations which provide their own UnaryClientInterceptor(s)
// these interceptors should be specified as the final elements to
// WithChainUnaryInterceptor.
//
// Alternatively, users may apply gPRC options produced from DialOptions to
// apply all applicable gRPC interceptors.
func (h *HeadersEnforcer) UnaryInterceptors() []grpc.UnaryClientInterceptor {
return []grpc.UnaryClientInterceptor{h.interceptUnary}
}
// DialOptions returns gRPC DialOptions consisting of unary and stream interceptors
// to enforce the presence and validity of expected headers.
func (h *HeadersEnforcer) DialOptions() []grpc.DialOption {
return []grpc.DialOption{
grpc.WithChainStreamInterceptor(h.interceptStream),
grpc.WithChainUnaryInterceptor(h.interceptUnary),
}
}
// CallOptions returns ClientOptions consisting of unary and stream interceptors
// to enforce the presence and validity of expected headers.
func (h *HeadersEnforcer) CallOptions() (copts []option.ClientOption) {
dopts := h.DialOptions()
for _, dopt := range dopts {
copts = append(copts, option.WithGRPCDialOption(dopt))
}
return
}
func (h *HeadersEnforcer) interceptUnary(ctx context.Context, method string, req, res interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
h.checkMetadata(ctx, method)
return invoker(ctx, method, req, res, cc, opts...)
}
func (h *HeadersEnforcer) interceptStream(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
h.checkMetadata(ctx, method)
return streamer(ctx, desc, cc, method, opts...)
}
// XGoogClientHeaderChecker is a HeaderChecker that ensures that the "x-goog-api-client"
// header is present on outgoing metadata.
var XGoogClientHeaderChecker = &HeaderChecker{
Key: "x-goog-api-client",
ValuesValidator: func(values ...string) error {
if len(values) == 0 {
return errors.New("expecting values")
}
for _, value := range values {
switch {
case strings.Contains(value, "gl-go/"):
// TODO: check for exact version strings.
return nil
default: // Add others here.
}
}
return errors.New("unmatched values")
},
}
// DefaultHeadersEnforcer returns a HeadersEnforcer that at bare minimum checks that
// the "x-goog-api-client" key is present in the outgoing metadata headers. On any
// validation failure, it will invoke log.Fatalf with the error message.
func DefaultHeadersEnforcer() *HeadersEnforcer {
return &HeadersEnforcer{
Checkers: []*HeaderChecker{XGoogClientHeaderChecker},
}
}
func (h *HeadersEnforcer) checkMetadata(ctx context.Context, method string) {
onFailure := h.OnFailure
if onFailure == nil {
lgr := log.New(os.Stderr, "", 0) // Do not log the time prefix, it is noisy in test failure logs.
onFailure = func(fmt_ string, args ...interface{}) {
lgr.Fatalf(fmt_, args...)
}
}
md, ok := metadata.FromOutgoingContext(ctx)
if !ok {
onFailure("Missing metadata for method %q", method)
return
}
checkers := h.Checkers
if len(checkers) == 0 {
// Instead use the default HeaderChecker.
checkers = append(checkers, XGoogClientHeaderChecker)
}
errBuf := new(bytes.Buffer)
for _, checker := range checkers {
hdrKey := checker.Key
outHdrValues, ok := md[hdrKey]
if !ok {
fmt.Fprintf(errBuf, "missing header %q\n", hdrKey)
continue
}
if err := checker.ValuesValidator(outHdrValues...); err != nil {
fmt.Fprintf(errBuf, "header %q: %v\n", hdrKey, err)
}
}
if errBuf.Len() != 0 {
onFailure("For method %q, errors:\n%s", method, errBuf)
return
}
}

@ -1,44 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"math/rand"
"sync"
"time"
)
// NewRand creates a new *rand.Rand seeded with t. The return value is safe for use
// with multiple goroutines.
func NewRand(t time.Time) *rand.Rand {
s := &lockedSource{src: rand.NewSource(t.UnixNano())}
return rand.New(s)
}
// lockedSource makes a rand.Source safe for use by multiple goroutines.
type lockedSource struct {
mu sync.Mutex
src rand.Source
}
func (ls *lockedSource) Int63() int64 {
ls.mu.Lock()
defer ls.mu.Unlock()
return ls.src.Int63()
}
func (ls *lockedSource) Seed(int64) {
panic("shouldn't be calling Seed")
}

@ -1,116 +0,0 @@
// Copyright 2019 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"bytes"
"fmt"
"path/filepath"
"runtime"
"strconv"
"testing"
"time"
)
// Retry runs function f for up to maxAttempts times until f returns successfully, and reports whether f was run successfully.
// It will sleep for the given period between invocations of f.
// Use the provided *testutil.R instead of a *testing.T from the function.
func Retry(t *testing.T, maxAttempts int, sleep time.Duration, f func(r *R)) bool {
for attempt := 1; attempt <= maxAttempts; attempt++ {
r := &R{Attempt: attempt, log: &bytes.Buffer{}}
f(r)
if !r.failed {
if r.log.Len() != 0 {
t.Logf("Success after %d attempts:%s", attempt, r.log.String())
}
return true
}
if attempt == maxAttempts {
t.Logf("FAILED after %d attempts:%s", attempt, r.log.String())
t.Fail()
}
time.Sleep(sleep)
}
return false
}
// RetryWithoutTest is a variant of Retry that does not use a testing parameter.
// It is meant for testing utilities that do not pass around the testing context, such as cloudrunci.
func RetryWithoutTest(maxAttempts int, sleep time.Duration, f func(r *R)) bool {
for attempt := 1; attempt <= maxAttempts; attempt++ {
r := &R{Attempt: attempt, log: &bytes.Buffer{}}
f(r)
if !r.failed {
if r.log.Len() != 0 {
r.Logf("Success after %d attempts:%s", attempt, r.log.String())
}
return true
}
if attempt == maxAttempts {
r.Logf("FAILED after %d attempts:%s", attempt, r.log.String())
return false
}
time.Sleep(sleep)
}
return false
}
// R is passed to each run of a flaky test run, manages state and accumulates log statements.
type R struct {
// The number of current attempt.
Attempt int
failed bool
log *bytes.Buffer
}
// Fail marks the run as failed, and will retry once the function returns.
func (r *R) Fail() {
r.failed = true
}
// Errorf is equivalent to Logf followed by Fail.
func (r *R) Errorf(s string, v ...interface{}) {
r.logf(s, v...)
r.Fail()
}
// Logf formats its arguments and records it in the error log.
// The text is only printed for the final unsuccessful run or the first successful run.
func (r *R) Logf(s string, v ...interface{}) {
r.logf(s, v...)
}
func (r *R) logf(s string, v ...interface{}) {
fmt.Fprint(r.log, "\n")
fmt.Fprint(r.log, lineNumber())
fmt.Fprintf(r.log, s, v...)
}
func lineNumber() string {
_, file, line, ok := runtime.Caller(3) // logf, public func, user function
if !ok {
return ""
}
return filepath.Base(file) + ":" + strconv.Itoa(line) + ": "
}

@ -1,135 +0,0 @@
/*
Copyright 2016 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package testutil
import (
"fmt"
"log"
"net"
"regexp"
"strconv"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// A Server is an in-process gRPC server, listening on a system-chosen port on
// the local loopback interface. Servers are for testing only and are not
// intended to be used in production code.
//
// To create a server, make a new Server, register your handlers, then call
// Start:
//
// srv, err := NewServer()
// ...
// mypb.RegisterMyServiceServer(srv.Gsrv, &myHandler)
// ....
// srv.Start()
//
// Clients should connect to the server with no security:
//
// conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
// ...
type Server struct {
Addr string
Port int
l net.Listener
Gsrv *grpc.Server
}
// NewServer creates a new Server. The Server will be listening for gRPC connections
// at the address named by the Addr field, without TLS.
func NewServer(opts ...grpc.ServerOption) (*Server, error) {
return NewServerWithPort(0, opts...)
}
// NewServerWithPort creates a new Server at a specific port. The Server will be listening
// for gRPC connections at the address named by the Addr field, without TLS.
func NewServerWithPort(port int, opts ...grpc.ServerOption) (*Server, error) {
l, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
if err != nil {
return nil, err
}
s := &Server{
Addr: l.Addr().String(),
Port: parsePort(l.Addr().String()),
l: l,
Gsrv: grpc.NewServer(opts...),
}
return s, nil
}
// Start causes the server to start accepting incoming connections.
// Call Start after registering handlers.
func (s *Server) Start() {
go func() {
if err := s.Gsrv.Serve(s.l); err != nil {
log.Printf("testutil.Server.Start: %v", err)
}
}()
}
// Close shuts down the server.
func (s *Server) Close() {
s.Gsrv.Stop()
s.l.Close()
}
// PageBounds converts an incoming page size and token from an RPC request into
// slice bounds and the outgoing next-page token.
//
// PageBounds assumes that the complete, unpaginated list of items exists as a
// single slice. In addition to the page size and token, PageBounds needs the
// length of that slice.
//
// PageBounds's first two return values should be used to construct a sub-slice of
// the complete, unpaginated slice. E.g. if the complete slice is s, then
// s[from:to] is the desired page. Its third return value should be set as the
// NextPageToken field of the RPC response.
func PageBounds(pageSize int, pageToken string, length int) (from, to int, nextPageToken string, err error) {
from, to = 0, length
if pageToken != "" {
from, err = strconv.Atoi(pageToken)
if err != nil {
return 0, 0, "", status.Errorf(codes.InvalidArgument, "bad page token: %v", err)
}
if from >= length {
return length, length, "", nil
}
}
if pageSize > 0 && from+pageSize < length {
to = from + pageSize
nextPageToken = strconv.Itoa(to)
}
return from, to, nextPageToken, nil
}
var portParser = regexp.MustCompile(`:[0-9]+`)
func parsePort(addr string) int {
res := portParser.FindAllString(addr, -1)
if len(res) == 0 {
panic(fmt.Errorf("parsePort: found no numbers in %s", addr))
}
stringPort := res[0][1:] // strip the :
p, err := strconv.ParseInt(stringPort, 10, 32)
if err != nil {
panic(err)
}
return int(p)
}

@ -1,78 +0,0 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"log"
"sync"
"time"
"go.opencensus.io/plugin/ocgrpc"
"go.opencensus.io/stats/view"
"go.opencensus.io/trace"
)
// TestExporter is a test utility exporter. It should be created with NewtestExporter.
type TestExporter struct {
mu sync.Mutex
Spans []*trace.SpanData
Stats chan *view.Data
Views []*view.View
}
// NewTestExporter creates a TestExporter and registers it with OpenCensus.
func NewTestExporter(views ...*view.View) *TestExporter {
if len(views) == 0 {
views = ocgrpc.DefaultClientViews
}
te := &TestExporter{Stats: make(chan *view.Data), Views: views}
view.RegisterExporter(te)
view.SetReportingPeriod(time.Millisecond)
if err := view.Register(views...); err != nil {
log.Fatal(err)
}
trace.RegisterExporter(te)
trace.ApplyConfig(trace.Config{DefaultSampler: trace.AlwaysSample()})
return te
}
// ExportSpan exports a span.
func (te *TestExporter) ExportSpan(s *trace.SpanData) {
te.mu.Lock()
defer te.mu.Unlock()
te.Spans = append(te.Spans, s)
}
// ExportView exports a view.
func (te *TestExporter) ExportView(vd *view.Data) {
if len(vd.Rows) > 0 {
select {
case te.Stats <- vd:
default:
}
}
}
// Unregister unregisters the exporter from OpenCensus.
func (te *TestExporter) Unregister() {
view.Unregister(te.Views...)
view.UnregisterExporter(te)
trace.UnregisterExporter(te)
view.SetReportingPeriod(0) // reset to default value
}

File diff suppressed because it is too large Load Diff

@ -1,32 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package impersonate is used to impersonate Google Credentials.
//
// # Required IAM roles
//
// In order to impersonate a service account the base service account must have
// the Service Account Token Creator role, roles/iam.serviceAccountTokenCreator,
// on the service account being impersonated. See
// https://cloud.google.com/iam/docs/understanding-service-accounts.
//
// Optionally, delegates can be used during impersonation if the base service
// account lacks the token creator role on the target. When using delegates,
// each service account must be granted roles/iam.serviceAccountTokenCreator
// on the next service account in the delgation chain.
//
// For example, if a base service account of SA1 is trying to impersonate target
// service account SA2 while using delegate service accounts DSA1 and DSA2,
// the following must be true:
//
// 1. Base service account SA1 has roles/iam.serviceAccountTokenCreator on
// DSA1.
// 2. DSA1 has roles/iam.serviceAccountTokenCreator on DSA2.
// 3. DSA2 has roles/iam.serviceAccountTokenCreator on target SA2.
//
// If the base credential is an authorized user and not a service account, or if
// the option WithQuotaProject is set, the target service account must have a
// role that grants the serviceusage.services.use permission such as
// roles/serviceusage.serviceUsageConsumer.
package impersonate

@ -1,129 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impersonate
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"time"
"golang.org/x/oauth2"
"google.golang.org/api/option"
htransport "google.golang.org/api/transport/http"
)
// IDTokenConfig for generating an impersonated ID token.
type IDTokenConfig struct {
// Audience is the `aud` field for the token, such as an API endpoint the
// token will grant access to. Required.
Audience string
// TargetPrincipal is the email address of the service account to
// impersonate. Required.
TargetPrincipal string
// IncludeEmail includes the service account's email in the token. The
// resulting token will include both an `email` and `email_verified`
// claim.
IncludeEmail bool
// Delegates are the service account email addresses in a delegation chain.
// Each service account must be granted roles/iam.serviceAccountTokenCreator
// on the next service account in the chain. Optional.
Delegates []string
}
// IDTokenSource creates an impersonated TokenSource that returns ID tokens
// configured with the provided config and using credentials loaded from
// Application Default Credentials as the base credentials. The tokens provided
// by the source are valid for one hour and are automatically refreshed.
func IDTokenSource(ctx context.Context, config IDTokenConfig, opts ...option.ClientOption) (oauth2.TokenSource, error) {
if config.Audience == "" {
return nil, fmt.Errorf("impersonate: an audience must be provided")
}
if config.TargetPrincipal == "" {
return nil, fmt.Errorf("impersonate: a target service account must be provided")
}
clientOpts := append(defaultClientOptions(), opts...)
client, _, err := htransport.NewClient(ctx, clientOpts...)
if err != nil {
return nil, err
}
its := impersonatedIDTokenSource{
client: client,
targetPrincipal: config.TargetPrincipal,
audience: config.Audience,
includeEmail: config.IncludeEmail,
}
for _, v := range config.Delegates {
its.delegates = append(its.delegates, formatIAMServiceAccountName(v))
}
return oauth2.ReuseTokenSource(nil, its), nil
}
type generateIDTokenRequest struct {
Audience string `json:"audience"`
IncludeEmail bool `json:"includeEmail"`
Delegates []string `json:"delegates,omitempty"`
}
type generateIDTokenResponse struct {
Token string `json:"token"`
}
type impersonatedIDTokenSource struct {
client *http.Client
targetPrincipal string
audience string
includeEmail bool
delegates []string
}
func (i impersonatedIDTokenSource) Token() (*oauth2.Token, error) {
now := time.Now()
genIDTokenReq := generateIDTokenRequest{
Audience: i.audience,
IncludeEmail: i.includeEmail,
Delegates: i.delegates,
}
bodyBytes, err := json.Marshal(genIDTokenReq)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to marshal request: %v", err)
}
url := fmt.Sprintf("%s/v1/%s:generateIdToken", iamCredentailsEndpoint, formatIAMServiceAccountName(i.targetPrincipal))
req, err := http.NewRequest("POST", url, bytes.NewReader(bodyBytes))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := i.client.Do(req)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to generate ID token: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}
var generateIDTokenResp generateIDTokenResponse
if err := json.Unmarshal(body, &generateIDTokenResp); err != nil {
return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
}
return &oauth2.Token{
AccessToken: generateIDTokenResp.Token,
// Generated ID tokens are good for one hour.
Expiry: now.Add(1 * time.Hour),
}, nil
}

@ -1,184 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impersonate
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"time"
"golang.org/x/oauth2"
"google.golang.org/api/option"
"google.golang.org/api/option/internaloption"
htransport "google.golang.org/api/transport/http"
)
var (
iamCredentailsEndpoint = "https://iamcredentials.googleapis.com"
oauth2Endpoint = "https://oauth2.googleapis.com"
)
// CredentialsConfig for generating impersonated credentials.
type CredentialsConfig struct {
// TargetPrincipal is the email address of the service account to
// impersonate. Required.
TargetPrincipal string
// Scopes that the impersonated credential should have. Required.
Scopes []string
// Delegates are the service account email addresses in a delegation chain.
// Each service account must be granted roles/iam.serviceAccountTokenCreator
// on the next service account in the chain. Optional.
Delegates []string
// Lifetime is the amount of time until the impersonated token expires. If
// unset the token's lifetime will be one hour and be automatically
// refreshed. If set the token may have a max lifetime of one hour and will
// not be refreshed. Service accounts that have been added to an org policy
// with constraints/iam.allowServiceAccountCredentialLifetimeExtension may
// request a token lifetime of up to 12 hours. Optional.
Lifetime time.Duration
// Subject is the sub field of a JWT. This field should only be set if you
// wish to impersonate as a user. This feature is useful when using domain
// wide delegation. Optional.
Subject string
}
// defaultClientOptions ensures the base credentials will work with the IAM
// Credentials API if no scope or audience is set by the user.
func defaultClientOptions() []option.ClientOption {
return []option.ClientOption{
internaloption.WithDefaultAudience("https://iamcredentials.googleapis.com/"),
internaloption.WithDefaultScopes("https://www.googleapis.com/auth/cloud-platform"),
}
}
// CredentialsTokenSource returns an impersonated CredentialsTokenSource configured with the provided
// config and using credentials loaded from Application Default Credentials as
// the base credentials.
func CredentialsTokenSource(ctx context.Context, config CredentialsConfig, opts ...option.ClientOption) (oauth2.TokenSource, error) {
if config.TargetPrincipal == "" {
return nil, fmt.Errorf("impersonate: a target service account must be provided")
}
if len(config.Scopes) == 0 {
return nil, fmt.Errorf("impersonate: scopes must be provided")
}
if config.Lifetime.Hours() > 12 {
return nil, fmt.Errorf("impersonate: max lifetime is 12 hours")
}
var isStaticToken bool
// Default to the longest acceptable value of one hour as the token will
// be refreshed automatically if not set.
lifetime := 3600 * time.Second
if config.Lifetime != 0 {
lifetime = config.Lifetime
// Don't auto-refresh token if a lifetime is configured.
isStaticToken = true
}
clientOpts := append(defaultClientOptions(), opts...)
client, _, err := htransport.NewClient(ctx, clientOpts...)
if err != nil {
return nil, err
}
// If a subject is specified a different auth-flow is initiated to
// impersonate as the provided subject (user).
if config.Subject != "" {
return user(ctx, config, client, lifetime, isStaticToken)
}
its := impersonatedTokenSource{
client: client,
targetPrincipal: config.TargetPrincipal,
lifetime: fmt.Sprintf("%.fs", lifetime.Seconds()),
}
for _, v := range config.Delegates {
its.delegates = append(its.delegates, formatIAMServiceAccountName(v))
}
its.scopes = make([]string, len(config.Scopes))
copy(its.scopes, config.Scopes)
if isStaticToken {
tok, err := its.Token()
if err != nil {
return nil, err
}
return oauth2.StaticTokenSource(tok), nil
}
return oauth2.ReuseTokenSource(nil, its), nil
}
func formatIAMServiceAccountName(name string) string {
return fmt.Sprintf("projects/-/serviceAccounts/%s", name)
}
type generateAccessTokenReq struct {
Delegates []string `json:"delegates,omitempty"`
Lifetime string `json:"lifetime,omitempty"`
Scope []string `json:"scope,omitempty"`
}
type generateAccessTokenResp struct {
AccessToken string `json:"accessToken"`
ExpireTime string `json:"expireTime"`
}
type impersonatedTokenSource struct {
client *http.Client
targetPrincipal string
lifetime string
scopes []string
delegates []string
}
// Token returns an impersonated Token.
func (i impersonatedTokenSource) Token() (*oauth2.Token, error) {
reqBody := generateAccessTokenReq{
Delegates: i.delegates,
Lifetime: i.lifetime,
Scope: i.scopes,
}
b, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to marshal request: %v", err)
}
url := fmt.Sprintf("%s/v1/%s:generateAccessToken", iamCredentailsEndpoint, formatIAMServiceAccountName(i.targetPrincipal))
req, err := http.NewRequest("POST", url, bytes.NewReader(b))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := i.client.Do(req)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to generate access token: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}
var accessTokenResp generateAccessTokenResp
if err := json.Unmarshal(body, &accessTokenResp); err != nil {
return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
}
expiry, err := time.Parse(time.RFC3339, accessTokenResp.ExpireTime)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to parse expiry: %v", err)
}
return &oauth2.Token{
AccessToken: accessTokenResp.AccessToken,
Expiry: expiry,
}, nil
}

@ -1,169 +0,0 @@
// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package impersonate
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"golang.org/x/oauth2"
)
func user(ctx context.Context, c CredentialsConfig, client *http.Client, lifetime time.Duration, isStaticToken bool) (oauth2.TokenSource, error) {
u := userTokenSource{
client: client,
targetPrincipal: c.TargetPrincipal,
subject: c.Subject,
lifetime: lifetime,
}
u.delegates = make([]string, len(c.Delegates))
for i, v := range c.Delegates {
u.delegates[i] = formatIAMServiceAccountName(v)
}
u.scopes = make([]string, len(c.Scopes))
copy(u.scopes, c.Scopes)
if isStaticToken {
tok, err := u.Token()
if err != nil {
return nil, err
}
return oauth2.StaticTokenSource(tok), nil
}
return oauth2.ReuseTokenSource(nil, u), nil
}
type claimSet struct {
Iss string `json:"iss"`
Scope string `json:"scope,omitempty"`
Sub string `json:"sub,omitempty"`
Aud string `json:"aud"`
Iat int64 `json:"iat"`
Exp int64 `json:"exp"`
}
type signJWTRequest struct {
Payload string `json:"payload"`
Delegates []string `json:"delegates,omitempty"`
}
type signJWTResponse struct {
// KeyID is the key used to sign the JWT.
KeyID string `json:"keyId"`
// SignedJwt contains the automatically generated header; the
// client-supplied payload; and the signature, which is generated using
// the key referenced by the `kid` field in the header.
SignedJWT string `json:"signedJwt"`
}
type exchangeTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
type userTokenSource struct {
client *http.Client
targetPrincipal string
subject string
scopes []string
lifetime time.Duration
delegates []string
}
func (u userTokenSource) Token() (*oauth2.Token, error) {
signedJWT, err := u.signJWT()
if err != nil {
return nil, err
}
return u.exchangeToken(signedJWT)
}
func (u userTokenSource) signJWT() (string, error) {
now := time.Now()
exp := now.Add(u.lifetime)
claims := claimSet{
Iss: u.targetPrincipal,
Scope: strings.Join(u.scopes, " "),
Sub: u.subject,
Aud: fmt.Sprintf("%s/token", oauth2Endpoint),
Iat: now.Unix(),
Exp: exp.Unix(),
}
payloadBytes, err := json.Marshal(claims)
if err != nil {
return "", fmt.Errorf("impersonate: unable to marshal claims: %v", err)
}
signJWTReq := signJWTRequest{
Payload: string(payloadBytes),
Delegates: u.delegates,
}
bodyBytes, err := json.Marshal(signJWTReq)
if err != nil {
return "", fmt.Errorf("impersonate: unable to marshal request: %v", err)
}
reqURL := fmt.Sprintf("%s/v1/%s:signJwt", iamCredentailsEndpoint, formatIAMServiceAccountName(u.targetPrincipal))
req, err := http.NewRequest("POST", reqURL, bytes.NewReader(bodyBytes))
if err != nil {
return "", fmt.Errorf("impersonate: unable to create request: %v", err)
}
req.Header.Set("Content-Type", "application/json")
rawResp, err := u.client.Do(req)
if err != nil {
return "", fmt.Errorf("impersonate: unable to sign JWT: %v", err)
}
body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
if err != nil {
return "", fmt.Errorf("impersonate: unable to read body: %v", err)
}
if c := rawResp.StatusCode; c < 200 || c > 299 {
return "", fmt.Errorf("impersonate: status code %d: %s", c, body)
}
var signJWTResp signJWTResponse
if err := json.Unmarshal(body, &signJWTResp); err != nil {
return "", fmt.Errorf("impersonate: unable to parse response: %v", err)
}
return signJWTResp.SignedJWT, nil
}
func (u userTokenSource) exchangeToken(signedJWT string) (*oauth2.Token, error) {
now := time.Now()
v := url.Values{}
v.Set("grant_type", "assertion")
v.Set("assertion_type", "http://oauth.net/grant_type/jwt/1.0/bearer")
v.Set("assertion", signedJWT)
rawResp, err := u.client.PostForm(fmt.Sprintf("%s/token", oauth2Endpoint), v)
if err != nil {
return nil, fmt.Errorf("impersonate: unable to exchange token: %v", err)
}
body, err := ioutil.ReadAll(io.LimitReader(rawResp.Body, 1<<20))
if err != nil {
return nil, fmt.Errorf("impersonate: unable to read body: %v", err)
}
if c := rawResp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("impersonate: status code %d: %s", c, body)
}
var tokenResp exchangeTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("impersonate: unable to parse response: %v", err)
}
return &oauth2.Token{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
Expiry: now.Add(time.Second * time.Duration(tokenResp.ExpiresIn)),
}, nil
}

@ -4,7 +4,6 @@ cloud.google.com/go
cloud.google.com/go/internal
cloud.google.com/go/internal/optional
cloud.google.com/go/internal/pubsub
cloud.google.com/go/internal/testutil
cloud.google.com/go/internal/trace
cloud.google.com/go/internal/version
# cloud.google.com/go/bigtable v1.18.1
@ -36,7 +35,6 @@ cloud.google.com/go/pubsub/apiv1/pubsubpb
cloud.google.com/go/pubsub/internal
cloud.google.com/go/pubsub/internal/distribution
cloud.google.com/go/pubsub/internal/scheduler
cloud.google.com/go/pubsub/pstest
# cloud.google.com/go/storage v1.29.0
## explicit; go 1.19
cloud.google.com/go/storage
@ -1464,7 +1462,6 @@ google.golang.org/api/compute/v1
google.golang.org/api/googleapi
google.golang.org/api/googleapi/transport
google.golang.org/api/iamcredentials/v1
google.golang.org/api/impersonate
google.golang.org/api/internal
google.golang.org/api/internal/gensupport
google.golang.org/api/internal/impersonate

Loading…
Cancel
Save