Skip to content

Commit

Permalink
Don't use globals
Browse files Browse the repository at this point in the history
  • Loading branch information
dricross committed Jan 31, 2025
1 parent 7e47531 commit 20257af
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
9 changes: 4 additions & 5 deletions cfg/aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,6 @@ const (
SourceAccountHeaderKey = "x-amz-source-account"
)

var (
sourceAccount = os.Getenv(envconfig.AmzSourceAccount) // populates the "x-amz-source-account" header
sourceArn = os.Getenv(envconfig.AmzSourceArn) // populates the "x-amz-source-arn" header
)

// newStsClient creates a new STS client with the provided config and options.
// Additionally, if specific environment variables are set, it also appends the confused deputy headers to requests
// made by the client. These headers allow resource-based policies to limit the permissions that a service has to
Expand All @@ -223,6 +218,10 @@ var (
//
// See https://docs.aws.amazon.com/IAM/latest/UserGuide/confused-deputy.html#cross-service-confused-deputy-prevention
func newStsClient(p client.ConfigProvider, cfgs ...*aws.Config) *sts.STS {

sourceAccount := os.Getenv(envconfig.AmzSourceAccount)
sourceArn := os.Getenv(envconfig.AmzSourceArn)

client := sts.New(p, cfgs...)
if sourceAccount != "" && sourceArn != "" {
client.Handlers.Sign.PushFront(func(r *request.Request) {
Expand Down
14 changes: 7 additions & 7 deletions cfg/aws/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/aws/amazon-cloudwatch-agent/cfg/envconfig"
)

func TestConfusedDeputyHeaders(t *testing.T) {
Expand Down Expand Up @@ -53,9 +55,9 @@ func TestConfusedDeputyHeaders(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Set global variables which will get picked up by newStsClient
sourceArn = tt.envSourceArn
sourceAccount = tt.envSourceAccount

t.Setenv(envconfig.AmzSourceAccount, tt.envSourceAccount)
t.Setenv(envconfig.AmzSourceArn, tt.envSourceArn)

client := newStsClient(mock.Session, &aws.Config{
// These are examples credentials pulled from:
Expand All @@ -76,14 +78,12 @@ func TestConfusedDeputyHeaders(t *testing.T) {
err := request.Sign()
require.NoError(t, err)

headerSourceArn := request.HTTPRequest.Header.Get("x-amz-source-arn")
headerSourceArn := request.HTTPRequest.Header.Get(SourceArnHeaderKey)
assert.Equal(t, tt.expectedHeaderArn, headerSourceArn)

headerSourceAccount := request.HTTPRequest.Header.Get("x-amz-source-account")
headerSourceAccount := request.HTTPRequest.Header.Get(SourceAccountHeaderKey)
assert.Equal(t, tt.expectedHeaderAccount, headerSourceAccount)
})
}

sourceArn = ""
sourceAccount = ""
}

0 comments on commit 20257af

Please sign in to comment.