diff --git a/go.mod b/go.mod
index 177fceb4fcf1d..50716677616e2 100644
--- a/go.mod
+++ b/go.mod
@@ -293,6 +293,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect
+ github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.25.0 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
diff --git a/go.sum b/go.sum
index 9c04612ef6dec..29e702d79fba5 100644
--- a/go.sum
+++ b/go.sum
@@ -926,6 +926,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91Liq
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA=
github.com/aws/aws-sdk-go-v2/service/kms v1.38.0 h1:+2/0Cq0R/audJhwM1GpJMg8X1TTrMKDFRLO5RMaNRU0=
github.com/aws/aws-sdk-go-v2/service/kms v1.38.0/go.mod h1:cQn6tAF77Di6m4huxovNM7NVAozWTZLsDRp9t8Z/WYk=
+github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1 h1:EabaKQAptxXAeSL0sXKqfupPe/CpH965wqoloUK0aMM=
+github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1/go.mod h1:c27kk10S36lBYgbG1jR3opn4OAS5Y/4wjJa1GiHK/X4=
github.com/aws/aws-sdk-go-v2/service/memorydb v1.26.0 h1:nO9RCZnfAIF5q43IDLWtf7vu/l16RKzeTkv5GObkyME=
github.com/aws/aws-sdk-go-v2/service/memorydb v1.26.0/go.mod h1:pfuDC5zBwunXdE44WT1PRbtzuXWGohKFcFLtv+ezI6k=
github.com/aws/aws-sdk-go-v2/service/opensearch v1.46.0 h1:eR65kYpNlKpGkkvg+A83hc0hpk2CHappaz1JAUCcxVs=
diff --git a/lib/events/fips.go b/lib/events/fips.go
index b00a1910df111..f0feb6d58958e 100644
--- a/lib/events/fips.go
+++ b/lib/events/fips.go
@@ -19,7 +19,7 @@
package events
import (
- "github.com/aws/aws-sdk-go/aws/endpoints"
+ "github.com/aws/aws-sdk-go-v2/aws"
"github.com/gravitational/teleport/api/types"
)
@@ -31,14 +31,14 @@ const (
)
var (
- fipsToAWS = map[types.ClusterAuditConfigSpecV2_FIPSEndpointState]endpoints.FIPSEndpointState{
- types.ClusterAuditConfigSpecV2_FIPS_UNSET: endpoints.FIPSEndpointStateUnset,
- types.ClusterAuditConfigSpecV2_FIPS_ENABLED: endpoints.FIPSEndpointStateEnabled,
- types.ClusterAuditConfigSpecV2_FIPS_DISABLED: endpoints.FIPSEndpointStateDisabled,
+ fipsToAWS = map[types.ClusterAuditConfigSpecV2_FIPSEndpointState]aws.FIPSEndpointState{
+ types.ClusterAuditConfigSpecV2_FIPS_UNSET: aws.FIPSEndpointStateUnset,
+ types.ClusterAuditConfigSpecV2_FIPS_ENABLED: aws.FIPSEndpointStateEnabled,
+ types.ClusterAuditConfigSpecV2_FIPS_DISABLED: aws.FIPSEndpointStateDisabled,
}
)
// FIPSProtoStateToAWSState converts a FIPS proto state to an aws endpoints.FIPSEndpointState
-func FIPSProtoStateToAWSState(state types.ClusterAuditConfigSpecV2_FIPSEndpointState) endpoints.FIPSEndpointState {
+func FIPSProtoStateToAWSState(state types.ClusterAuditConfigSpecV2_FIPSEndpointState) aws.FIPSEndpointState {
return fipsToAWS[state]
}
diff --git a/lib/integrations/awsoidc/clientsv1.go b/lib/integrations/awsoidc/clientsv1.go
deleted file mode 100644
index 576badf79951d..0000000000000
--- a/lib/integrations/awsoidc/clientsv1.go
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package awsoidc
-
-import (
- "context"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/credentials"
- "github.com/aws/aws-sdk-go/aws/credentials/stscreds"
- "github.com/aws/aws-sdk-go/aws/endpoints"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/gravitational/trace"
-
- "github.com/gravitational/teleport/api/types"
- utilsaws "github.com/gravitational/teleport/api/utils/aws"
- "github.com/gravitational/teleport/lib/modules"
- "github.com/gravitational/teleport/lib/utils/aws/stsutils"
-)
-
-// FetchToken returns the token.
-func (j IdentityToken) FetchToken(ctx credentials.Context) ([]byte, error) {
- return []byte(j), nil
-}
-
-// IntegrationTokenGenerator is an interface that indicates which APIs are required to generate an Integration Token.
-type IntegrationTokenGenerator interface {
- // GetIntegration returns the specified integration resources.
- GetIntegration(ctx context.Context, name string) (types.Integration, error)
-
- // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action.
- GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error)
-}
-
-// NewSessionV1 creates a new AWS Session for the region using the integration as source of credentials.
-// This session is usable for AWS SDK Go V1.
-func NewSessionV1(ctx context.Context, client IntegrationTokenGenerator, region string, integrationName string) (*session.Session, error) {
- if region != "" {
- if err := utilsaws.IsValidRegion(region); err != nil {
- return nil, trace.Wrap(err)
- }
- }
- integration, err := client.GetIntegration(ctx, integrationName)
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- awsOIDCIntegration := integration.GetAWSOIDCIntegrationSpec()
- if awsOIDCIntegration == nil {
- return nil, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind())
- }
-
- useFIPSEndpoint := endpoints.FIPSEndpointStateUnset
- if modules.GetModules().IsBoringBinary() {
- useFIPSEndpoint = endpoints.FIPSEndpointStateEnabled
- }
-
- sess, err := session.NewSessionWithOptions(session.Options{
- SharedConfigState: session.SharedConfigDisable,
- Config: aws.Config{
- UseFIPSEndpoint: useFIPSEndpoint,
- },
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- // AWS SDK calls FetchToken everytime the session is no longer valid (or there's no session).
- // Generating a token here and using it as a Static would make this token valid for the Max Duration Session for the current AWS Role (usually, 1 hour).
- // Instead, it generates a token everytime the Session's client requests a new token, ensuring it always receives a fresh one.
- var integrationTokenFetcher IntegrationTokenFetcher = func(ctx context.Context) ([]byte, error) {
- token, err := client.GenerateAWSOIDCToken(ctx, integrationName)
- return []byte(token), trace.Wrap(err)
- }
-
- stsSTS := stsutils.NewV1(sess)
- roleProvider := stscreds.NewWebIdentityRoleProviderWithOptions(
- stsSTS,
- awsOIDCIntegration.RoleARN,
- "",
- integrationTokenFetcher,
- )
- awsCredentials := credentials.NewCredentials(roleProvider)
-
- session, err := session.NewSessionWithOptions(session.Options{
- SharedConfigState: session.SharedConfigDisable,
- Config: aws.Config{
- Region: aws.String(region),
- Credentials: awsCredentials,
- UseFIPSEndpoint: useFIPSEndpoint,
- },
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return session, nil
-}
-
-// IntegrationTokenFetcher handles dynamic token generation using a callback function.
-// Useful to embed as a [stscreds.TokenFetcher].
-type IntegrationTokenFetcher func(context.Context) ([]byte, error)
-
-// FetchToken returns a token by calling the callback function.
-func (genFn IntegrationTokenFetcher) FetchToken(ctx context.Context) ([]byte, error) {
- token, err := genFn(ctx)
- return token, trace.Wrap(err)
-}
diff --git a/lib/integrations/awsoidc/clientsv1_test.go b/lib/integrations/awsoidc/clientsv1_test.go
deleted file mode 100644
index 66a350ebe67a6..0000000000000
--- a/lib/integrations/awsoidc/clientsv1_test.go
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package awsoidc
-
-import (
- "context"
- "testing"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/google/uuid"
- "github.com/gravitational/trace"
- "github.com/stretchr/testify/require"
-
- "github.com/gravitational/teleport/api/types"
-)
-
-type mockIntegrationsTokenGenerator struct {
- proxies []types.Server
- integrations map[string]types.Integration
- tokenCallsCount int
-}
-
-// GetIntegration returns the specified integration resources.
-func (m *mockIntegrationsTokenGenerator) GetIntegration(ctx context.Context, name string) (types.Integration, error) {
- if ig, found := m.integrations[name]; found {
- return ig, nil
- }
-
- return nil, trace.NotFound("integration not found")
-}
-
-// GetProxies returns a list of registered proxies.
-func (m *mockIntegrationsTokenGenerator) GetProxies() ([]types.Server, error) {
- return m.proxies, nil
-}
-
-// GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action.
-func (m *mockIntegrationsTokenGenerator) GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) {
- m.tokenCallsCount++
- return uuid.NewString(), nil
-}
-
-func TestNewSessionV1(t *testing.T) {
- ctx := context.Background()
-
- dummyIntegration, err := types.NewIntegrationAWSOIDC(
- types.Metadata{Name: "myawsintegration"},
- &types.AWSOIDCIntegrationSpecV1{
- RoleARN: "arn:aws:sts::123456789012:role/TestRole",
- },
- )
- require.NoError(t, err)
-
- dummyProxy, err := types.NewServer(
- "proxy-123", types.KindProxy,
- types.ServerSpecV2{
- PublicAddrs: []string{"https://localhost:3080/"},
- },
- )
- require.NoError(t, err)
-
- for _, tt := range []struct {
- name string
- region string
- integration string
- tokenFetchCount int
- expectedErr require.ErrorAssertionFunc
- sessionValidator func(*testing.T, *session.Session)
- }{
- {
- name: "valid",
- region: "us-dummy-1",
- integration: "myawsintegration",
- expectedErr: require.NoError,
- sessionValidator: func(t *testing.T, s *session.Session) {
- require.Equal(t, aws.String("us-dummy-1"), s.Config.Region)
- },
- },
- {
- name: "valid with empty region",
- region: "",
- integration: "myawsintegration",
- expectedErr: require.NoError,
- sessionValidator: func(t *testing.T, s *session.Session) {
- require.Equal(t, "", aws.StringValue(s.Config.Region))
- },
- },
- {
- name: "not found error when integration is missing",
- region: "us-dummy-1",
- integration: "not-found",
- expectedErr: notFoundCheck,
- },
- } {
- t.Run(tt.name, func(t *testing.T) {
- mockTokenGenertor := &mockIntegrationsTokenGenerator{
- proxies: []types.Server{dummyProxy},
- integrations: map[string]types.Integration{
- dummyIntegration.GetName(): dummyIntegration,
- },
- }
- awsSessionOut, err := NewSessionV1(ctx, mockTokenGenertor, tt.region, tt.integration)
-
- tt.expectedErr(t, err)
- if tt.sessionValidator != nil {
- tt.sessionValidator(t, awsSessionOut)
- }
- require.Zero(t, tt.tokenFetchCount)
- })
- }
-
-}
diff --git a/lib/integrations/externalauditstorage/configurator_test.go b/lib/integrations/externalauditstorage/configurator_test.go
index 631fae9337aed..df845e15a1e15 100644
--- a/lib/integrations/externalauditstorage/configurator_test.go
+++ b/lib/integrations/externalauditstorage/configurator_test.go
@@ -25,9 +25,9 @@ import (
"testing"
"time"
+ "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
- "github.com/aws/aws-sdk-go/aws"
"github.com/google/uuid"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go
index 1e01cfed5606d..083001fd30bae 100644
--- a/lib/srv/alpnproxy/aws_local_proxy_test.go
+++ b/lib/srv/alpnproxy/aws_local_proxy_test.go
@@ -19,8 +19,6 @@
package alpnproxy
import (
- "context"
- "encoding/xml"
"net/http"
"net/http/httptest"
"testing"
@@ -28,7 +26,6 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
- "github.com/aws/aws-sdk-go-v2/service/sts"
ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/require"
@@ -39,8 +36,8 @@ func TestAWSAccessMiddleware(t *testing.T) {
t.Parallel()
assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name"
- localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "")
- assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token")
+ localCred := aws.Credentials{AccessKeyID: "local-proxy", SecretAccessKey: "local-proxy-secret"}
+ assumedRoleCred := aws.Credentials{AccessKeyID: "assumed-role", SecretAccessKey: "assumed-role-secret", SessionToken: "assumed-role-token"}
m := &AWSAccessMiddleware{
AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""),
@@ -48,11 +45,10 @@ func TestAWSAccessMiddleware(t *testing.T) {
require.NoError(t, m.CheckAndSetDefaults())
stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil)
-
- awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now())
+ awsutils.NewSigner("sts").SignHTTP(t.Context(), localCred, stsRequestByLocalProxyCred, awsutils.EmptyPayloadHash, "sts", "us-west-1", time.Now())
requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil)
- awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now())
+ awsutils.NewSigner("s3").SignHTTP(t.Context(), assumedRoleCred, requestByAssumedRole, awsutils.EmptyPayloadHash, "s3", "us-west-1", time.Now())
t.Run("request no authorization", func(t *testing.T) {
recorder := httptest.NewRecorder()
@@ -99,34 +95,60 @@ func TestAWSAccessMiddleware(t *testing.T) {
})
}
-func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response {
- t.Helper()
+// IdentityResult represents the identitiy result of an AWS response.
+type IdentityResult struct {
+ ARN string `xml:"Arn"`
+}
- credValue, err := provider.Retrieve(context.Background())
- require.NoError(t, err)
+// ResponseMetadata contains the metadata of a AWS response.
+type ResponseMetadata struct {
+ RequestID string `xml:"RequestID"`
+ StatusCode int `xml:"StatusCode"`
+}
- body, err := awsutils.MarshalXML(
- xml.Name{
- Local: "AssumeRoleResponse",
- Space: "https://sts.amazonaws.com/doc/2011-06-15/",
- },
- map[string]any{
- "AssumeRoleResult": sts.AssumeRoleOutput{
- AssumedRoleUser: &ststypes.AssumedRoleUser{
- Arn: aws.String(roleARN),
- },
- Credentials: &ststypes.Credentials{
- AccessKeyId: aws.String(credValue.AccessKeyID),
- SecretAccessKey: aws.String(credValue.SecretAccessKey),
- SessionToken: aws.String(credValue.SessionToken),
- },
+// AssumeRoleResult contains the assume role result.
+type AssumeRoleResult struct {
+ // AssumedRoleUser is the assumed user.
+ AssumedRoleUser IdentityResult `xml:"AssumedRoleUser"`
+ // Credentials is the generated credentials.
+ Credentials ststypes.Credentials `xml:"Credentials"`
+}
+
+// AssumeRoleResponse is the response of assume role.
+type AssumeRoleResponse struct {
+ // AssumeRoleResult is the resulting response from assume role.
+ AssumeRoleResult AssumeRoleResult `xml:"AssumeRoleResult"`
+ // Response is the response metadata.
+ Response ResponseMetadata `xml:"ResponseMetadata"`
+}
+
+// GetCallerIdentityResponse is the response of get caller identity call.
+type GetCallerIdentityResponse struct {
+ // AssumeRoleResult is the resulting response from assume role.
+ GetCallerIdentityResult IdentityResult `xml:"GetCallerIdentityResult"`
+ // Response is the response metadata.
+ Response ResponseMetadata `xml:"ResponseMetadata"`
+}
+
+func assumeRoleResponse(t *testing.T, roleARN string, creds aws.Credentials) *http.Response {
+ t.Helper()
+
+ body, err := awsutils.MarshalXML("AssumeRoleResponse", "https://sts.amazonaws.com/doc/2011-06-15/", AssumeRoleResponse{
+ AssumeRoleResult: AssumeRoleResult{
+ AssumedRoleUser: IdentityResult{
+ ARN: roleARN,
},
- "ResponseMetadata": map[string]any{
- "StatusCode": http.StatusOK,
- "RequestID": "22222222-3333-3333-3333-333333333333",
+ Credentials: ststypes.Credentials{
+ AccessKeyId: aws.String(creds.AccessKeyID),
+ SecretAccessKey: aws.String(creds.SecretAccessKey),
+ SessionToken: aws.String(creds.SessionToken),
},
},
- )
+ Response: ResponseMetadata{
+ StatusCode: http.StatusOK,
+ RequestID: "22222222-3333-3333-3333-333333333333",
+ },
+ })
require.NoError(t, err)
return fakeHTTPResponse(http.StatusOK, body)
}
@@ -134,21 +156,15 @@ func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsPr
func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response {
t.Helper()
- body, err := awsutils.MarshalXML(
- xml.Name{
- Local: "GetCallerIdentityResponse",
- Space: "https://sts.amazonaws.com/doc/2011-06-15/",
+ body, err := awsutils.MarshalXML("GetCallerIdentityResponse", "https://sts.amazonaws.com/doc/2011-06-15/", GetCallerIdentityResponse{
+ GetCallerIdentityResult: IdentityResult{
+ ARN: roleARN,
},
- map[string]any{
- "GetCallerIdentityResult": sts.GetCallerIdentityOutput{
- Arn: aws.String(roleARN),
- },
- "ResponseMetadata": map[string]any{
- "StatusCode": http.StatusOK,
- "RequestID": "22222222-3333-3333-3333-333333333333",
- },
+ Response: ResponseMetadata{
+ StatusCode: http.StatusOK,
+ RequestID: "22222222-3333-3333-3333-333333333333",
},
- )
+ })
require.NoError(t, err)
return fakeHTTPResponse(http.StatusOK, body)
}
diff --git a/lib/srv/app/aws/endpoints_test.go b/lib/srv/app/aws/endpoints_test.go
index 2690482b71ad1..7a1116be35430 100644
--- a/lib/srv/app/aws/endpoints_test.go
+++ b/lib/srv/app/aws/endpoints_test.go
@@ -19,20 +19,20 @@
package aws
import (
- "bytes"
"net/http"
"testing"
"time"
- "github.com/aws/aws-sdk-go/aws/credentials"
- v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/stretchr/testify/require"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)
func TestResolveEndpoints(t *testing.T) {
- signer := v4.NewSigner(credentials.NewStaticCredentials("fakeClientKeyID", "fakeClientSecret", ""))
+ creds := aws.Credentials{AccessKeyID: "fakeClientKeyID", SecretAccessKey: "fakeClientSecret"}
+ signer := v4.NewSigner()
region := "us-east-1"
now := time.Now()
@@ -40,7 +40,7 @@ func TestResolveEndpoints(t *testing.T) {
req, err := http.NewRequest("GET", "http://localhost", nil)
require.NoError(t, err)
- _, err = signer.Sign(req, bytes.NewReader(nil), "ecr", "us-east-1", now)
+ err = signer.SignHTTP(t.Context(), creds, req, awsutils.EmptyPayloadHash, "ecr", "us-east-1", now)
require.NoError(t, err)
_, err = resolveEndpoint(req, awsutils.AuthorizationHeader)
@@ -52,7 +52,7 @@ func TestResolveEndpoints(t *testing.T) {
require.NoError(t, err)
req.Header.Set("X-Forwarded-Host", "some-service.us-east-1.amazonaws.com")
- _, err = signer.Sign(req, bytes.NewReader(nil), "some-service", region, now)
+ err = signer.SignHTTP(t.Context(), creds, req, awsutils.EmptyPayloadHash, "some-service", region, now)
require.NoError(t, err)
endpoint, err := resolveEndpoint(req, awsutils.AuthorizationHeader)
diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go
index 2ad691537fa4a..476ad04e42c1c 100644
--- a/lib/srv/app/aws/handler.go
+++ b/lib/srv/app/aws/handler.go
@@ -165,6 +165,11 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt
return trace.Wrap(err)
}
+ reqCloneForAudit, err := cloneRequest(unsignedReq)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
awsCfg, err := s.AWSConfigProvider.GetConfig(s.closeContext, re.SigningRegion,
awsconfig.WithDetailedAssumeRole(awsconfig.AssumeRole{
RoleARN: sessCtx.Identity.RouteToApp.AWSRoleARN,
@@ -189,7 +194,7 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt
}
recorder := httplib.NewResponseStatusRecorder(w)
s.fwd.ServeHTTP(recorder, signedReq)
- s.emitAudit(sessCtx, unsignedReq, uint32(recorder.Status()), re)
+ s.emitAudit(sessCtx, reqCloneForAudit, uint32(recorder.Status()), re)
return nil
}
diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go
index 182de8945b7f8..e51210582db35 100644
--- a/lib/srv/app/aws/handler_test.go
+++ b/lib/srv/app/aws/handler_test.go
@@ -31,16 +31,13 @@ import (
"testing"
"time"
- credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/awserr"
- "github.com/aws/aws-sdk-go/aws/client"
- "github.com/aws/aws-sdk-go/aws/credentials"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/aws/aws-sdk-go/service/dynamodb"
- "github.com/aws/aws-sdk-go/service/lambda"
- "github.com/aws/aws-sdk-go/service/s3"
- "github.com/aws/aws-sdk-go/service/sts"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ transporthttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
+ "github.com/aws/aws-sdk-go-v2/credentials"
+ "github.com/aws/aws-sdk-go-v2/service/dynamodb"
+ "github.com/aws/aws-sdk-go-v2/service/lambda"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
@@ -66,47 +63,51 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
-type makeRequest func(url string, provider client.ConfigProvider, awsHost string) error
+type makeRequest func(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error
-func s3Request(url string, provider client.ConfigProvider, awsHost string) error {
- return s3RequestWithTransport(url, provider, &requestByHTTPSProxy{xForwardedHost: awsHost})
+func s3Request(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
+ return s3RequestWithTransport(ctx, url, region, provider, &requestByHTTPSProxy{xForwardedHost: awsHost})
}
-func s3RequestByAssumedRole(url string, provider client.ConfigProvider, awsHost string) error {
- return s3RequestWithTransport(url, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost})
+func s3RequestByAssumedRole(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
+ return s3RequestWithTransport(ctx, url, region, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost})
}
-func s3RequestWithTransport(url string, provider client.ConfigProvider, transport http.RoundTripper) error {
- s3Client := s3.New(provider, &aws.Config{
- Endpoint: &url,
- MaxRetries: aws.Int(0),
+func s3RequestWithTransport(ctx context.Context, url string, region string, provider aws.CredentialsProvider, transport http.RoundTripper) error {
+ s3Client := s3.New(s3.Options{
+ Credentials: provider,
+ BaseEndpoint: &url,
+ Region: region,
+ RetryMaxAttempts: 0,
HTTPClient: &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
},
})
- _, err := s3Client.ListBuckets(&s3.ListBucketsInput{})
+ _, err := s3Client.ListBuckets(ctx, &s3.ListBucketsInput{})
return err
}
-func dynamoRequest(url string, provider client.ConfigProvider, awsHost string) error {
- return dynamoRequestWithTransport(url, provider, &requestByHTTPSProxy{xForwardedHost: awsHost})
+func dynamoRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
+ return dynamoRequestWithTransport(ctx, url, region, provider, &requestByHTTPSProxy{xForwardedHost: awsHost})
}
-func dynamoRequestByAssumedRole(url string, provider client.ConfigProvider, awsHost string) error {
- return dynamoRequestWithTransport(url, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost})
+func dynamoRequestByAssumedRole(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
+ return dynamoRequestWithTransport(ctx, url, region, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost})
}
-func dynamoRequestWithTransport(url string, provider client.ConfigProvider, transport http.RoundTripper) error {
- dynamoClient := dynamodb.New(provider, &aws.Config{
- Endpoint: &url,
- MaxRetries: aws.Int(0),
+func dynamoRequestWithTransport(ctx context.Context, url string, region string, provider aws.CredentialsProvider, transport http.RoundTripper) error {
+ dynamoClient := dynamodb.New(dynamodb.Options{
+ Credentials: provider,
+ BaseEndpoint: &url,
+ Region: region,
+ RetryMaxAttempts: 0,
HTTPClient: &http.Client{
Transport: transport,
Timeout: 5 * time.Second,
},
})
- _, err := dynamoClient.Scan(&dynamodb.ScanInput{
+ _, err := dynamoClient.Scan(ctx, &dynamodb.ScanInput{
TableName: aws.String("test-table"),
})
return err
@@ -116,30 +117,32 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran
// size. Use a 1MB limit instead of the actual 70MB limit.
const maxTestHTTPRequestBodySize = 1 << 20
-func maxSizeExceededRequest(url string, provider client.ConfigProvider, awsHost string) error {
+func maxSizeExceededRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
// fake an upload that's too large
payload := strings.Repeat("x", maxTestHTTPRequestBodySize)
- return lambdaRequestWithPayload(url, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost})
+ return lambdaRequestWithPayload(ctx, url, region, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost})
}
-func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error {
+func lambdaRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
// fake a zip file with 70% of the max limit. Lambda will base64 encode it,
// which bloats it up, and our proxy should still handle it.
const size = (maxTestHTTPRequestBodySize * 7) / 10
payload := strings.Repeat("x", size)
- return lambdaRequestWithPayload(url, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost})
+ return lambdaRequestWithPayload(ctx, url, region, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost})
}
-func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string, transport http.RoundTripper) error {
- lambdaClient := lambda.New(provider, &aws.Config{
- Endpoint: &url,
- MaxRetries: aws.Int(0),
+func lambdaRequestWithPayload(ctx context.Context, url string, region string, provider aws.CredentialsProvider, payload string, transport http.RoundTripper) error {
+ lambdaClient := lambda.New(lambda.Options{
+ Credentials: provider,
+ BaseEndpoint: &url,
+ Region: region,
+ RetryMaxAttempts: 0,
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
Transport: transport,
},
})
- _, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{
+ _, err := lambdaClient.UpdateFunctionCode(ctx, &lambda.UpdateFunctionCodeInput{
FunctionName: aws.String("fakeFunc"),
ZipFile: []byte(payload),
})
@@ -147,17 +150,19 @@ func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payloa
}
func assumeRoleRequest(requestDuration time.Duration) makeRequest {
- return func(url string, provider client.ConfigProvider, awsHost string) error {
- stsClient := stsutils.NewV1(provider, &aws.Config{
- Endpoint: &url,
- MaxRetries: aws.Int(0),
+ return func(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error {
+ stsClient := stsutils.NewFromConfig(aws.Config{
+ Credentials: provider,
+ BaseEndpoint: &url,
+ Region: region,
+ RetryMaxAttempts: 0,
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
Transport: &requestByHTTPSProxy{xForwardedHost: awsHost},
},
})
- _, err := stsClient.AssumeRole(&sts.AssumeRoleInput{
- DurationSeconds: aws.Int64(int64(requestDuration.Seconds())),
+ _, err := stsClient.AssumeRole(ctx, &sts.AssumeRoleInput{
+ DurationSeconds: aws.Int32(int32(requestDuration.Seconds())),
RoleSessionName: aws.String("test-session"),
RoleArn: aws.String("arn:aws:iam::123456789012:role/test-role"),
})
@@ -191,9 +196,9 @@ func (r requestByAssumedRoleTransport) RoundTrip(req *http.Request) (*http.Respo
func hasStatusCode(wantStatusCode int) require.ErrorAssertionFunc {
return func(t require.TestingT, err error, msgAndArgs ...interface{}) {
- var apiErr awserr.RequestFailure
- require.ErrorAs(t, err, &apiErr, msgAndArgs...)
- require.Equal(t, wantStatusCode, apiErr.StatusCode(), msgAndArgs...)
+ var respErr *transporthttp.ResponseError
+ require.ErrorAs(t, err, &respErr, msgAndArgs...)
+ require.Equal(t, wantStatusCode, respErr.Response.StatusCode, msgAndArgs...)
}
}
@@ -224,47 +229,44 @@ func TestAWSSignerHandler(t *testing.T) {
require.NoError(t, err)
tests := []struct {
- name string
- app types.Application
- awsClientSession *session.Session
- awsConfigProvider awsconfig.Provider
- request makeRequest
- advanceClock time.Duration
- wantHost string
- wantAuthCredService string
- wantAuthCredRegion string
- wantAuthCredKeyID string
- wantEventType events.AuditEvent
- wantAssumedRole string
- skipVerifySignature bool
- verifySentRequest func(*testing.T, *http.Request)
- errAssertionFns []require.ErrorAssertionFunc
+ name string
+ app types.Application
+ awsCredentialsProvider aws.CredentialsProvider
+ awsRegion string
+ awsConfigProvider awsconfig.Provider
+ request makeRequest
+ advanceClock time.Duration
+ wantHost string
+ wantAuthCredService string
+ wantAuthCredRegion string
+ wantAuthCredKeyID string
+ wantEventType events.AuditEvent
+ wantAssumedRole string
+ skipVerifySignature bool
+ verifySentRequest func(*testing.T, *http.Request)
+ errAssertionFns []require.ErrorAssertionFunc
}{
{
- name: "s3 access",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-west-2"),
- })),
- request: s3Request,
- wantHost: "s3.us-west-2.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "s3",
- wantAuthCredRegion: "us-west-2",
- wantEventType: &events.AppSessionRequest{},
+ name: "s3 access",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-west-2",
+ request: s3Request,
+ wantHost: "s3.us-west-2.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "s3",
+ wantAuthCredRegion: "us-west-2",
+ wantEventType: &events.AppSessionRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "s3 access with integration",
- app: consoleAppWithIntegration,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-west-2"),
- })),
- request: s3Request,
+ name: "s3 access with integration",
+ app: consoleAppWithIntegration,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-west-2",
+ request: s3Request,
awsConfigProvider: &mocks.AWSConfigProvider{
OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{
Integration: awsOIDCIntegration,
@@ -281,144 +283,126 @@ func TestAWSSignerHandler(t *testing.T) {
},
},
{
- name: "s3 access with different region",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-west-1"),
- })),
- request: s3Request,
- wantHost: "s3.us-west-1.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "s3",
- wantAuthCredRegion: "us-west-1",
- wantEventType: &events.AppSessionRequest{},
+ name: "s3 access with different region",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-west-1",
+ request: s3Request,
+ wantHost: "s3.us-west-1.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "s3",
+ wantAuthCredRegion: "us-west-1",
+ wantEventType: &events.AppSessionRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "s3 access missing credentials",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: credentials.AnonymousCredentials,
- Region: aws.String("us-west-1"),
- })),
- request: s3Request,
+ name: "s3 access missing credentials",
+ app: consoleApp,
+ awsCredentialsProvider: aws.AnonymousCredentials{},
+ awsRegion: "us-west-1",
+ request: s3Request,
errAssertionFns: []require.ErrorAssertionFunc{
hasStatusCode(http.StatusBadRequest),
},
},
{
- name: "s3 access by assumed role",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForAssumedRole,
- Region: aws.String("us-west-2"),
- })),
- request: s3RequestByAssumedRole,
- wantHost: "s3.us-west-2.amazonaws.com",
- wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID
- wantAuthCredService: "s3",
- wantAuthCredRegion: "us-west-2",
- wantEventType: &events.AppSessionRequest{},
- wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit
- skipVerifySignature: true, // not re-signing
+ name: "s3 access by assumed role",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForAssumedRole,
+ awsRegion: "us-west-2",
+ request: s3RequestByAssumedRole,
+ wantHost: "s3.us-west-2.amazonaws.com",
+ wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID
+ wantAuthCredService: "s3",
+ wantAuthCredRegion: "us-west-2",
+ wantEventType: &events.AppSessionRequest{},
+ wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit
+ skipVerifySignature: true, // not re-signing
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "DynamoDB access",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: dynamoRequest,
- wantHost: "dynamodb.us-east-1.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "dynamodb",
- wantAuthCredRegion: "us-east-1",
- wantEventType: &events.AppSessionDynamoDBRequest{},
+ name: "DynamoDB access",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: dynamoRequest,
+ wantHost: "dynamodb.us-east-1.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "dynamodb",
+ wantAuthCredRegion: "us-east-1",
+ wantEventType: &events.AppSessionDynamoDBRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "DynamoDB access with different region",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-west-1"),
- })),
- request: dynamoRequest,
- wantHost: "dynamodb.us-west-1.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "dynamodb",
- wantAuthCredRegion: "us-west-1",
- wantEventType: &events.AppSessionDynamoDBRequest{},
+ name: "DynamoDB access with different region",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-west-1",
+ request: dynamoRequest,
+ wantHost: "dynamodb.us-west-1.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "dynamodb",
+ wantAuthCredRegion: "us-west-1",
+ wantEventType: &events.AppSessionDynamoDBRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "DynamoDB access missing credentials",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: credentials.AnonymousCredentials,
- Region: aws.String("us-west-1"),
- })),
- request: dynamoRequest,
+ name: "DynamoDB access missing credentials",
+ app: consoleApp,
+ awsCredentialsProvider: aws.AnonymousCredentials{},
+ awsRegion: "us-west-1",
+ request: dynamoRequest,
errAssertionFns: []require.ErrorAssertionFunc{
hasStatusCode(http.StatusBadRequest),
},
},
{
- name: "DynamoDB access by assumed role",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForAssumedRole,
- Region: aws.String("us-east-1"),
- })),
- request: dynamoRequestByAssumedRole,
- wantHost: "dynamodb.us-east-1.amazonaws.com",
- wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID
- wantAuthCredService: "dynamodb",
- wantAuthCredRegion: "us-east-1",
- wantEventType: &events.AppSessionDynamoDBRequest{},
- wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit
- skipVerifySignature: true, // not re-signing
+ name: "DynamoDB access by assumed role",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForAssumedRole,
+ awsRegion: "us-east-1",
+ request: dynamoRequestByAssumedRole,
+ wantHost: "dynamodb.us-east-1.amazonaws.com",
+ wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID
+ wantAuthCredService: "dynamodb",
+ wantAuthCredRegion: "us-east-1",
+ wantEventType: &events.AppSessionDynamoDBRequest{},
+ wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit
+ skipVerifySignature: true, // not re-signing
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "Lambda access",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: lambdaRequest,
- wantHost: "lambda.us-east-1.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "lambda",
- wantAuthCredRegion: "us-east-1",
- wantEventType: &events.AppSessionRequest{},
+ name: "Lambda access",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: lambdaRequest,
+ wantHost: "lambda.us-east-1.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "lambda",
+ wantAuthCredRegion: "us-east-1",
+ wantEventType: &events.AppSessionRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "Request exceeding max size",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: maxSizeExceededRequest,
- wantHost: "lambda.us-east-1.amazonaws.com",
+ name: "Request exceeding max size",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: maxSizeExceededRequest,
+ wantHost: "lambda.us-east-1.amazonaws.com",
errAssertionFns: []require.ErrorAssertionFunc{
// TODO(gavin): change this to [http.StatusRequestEntityTooLarge]
// after updating [trace.ErrorToCode].
@@ -426,52 +410,46 @@ func TestAWSSignerHandler(t *testing.T) {
},
},
{
- name: "AssumeRole success (shorter identity duration)",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: assumeRoleRequest(2 * time.Hour),
- advanceClock: 10 * time.Minute,
- wantHost: "sts.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "sts",
- wantAuthCredRegion: "us-east-1",
- wantEventType: &events.AppSessionRequest{},
- verifySentRequest: verifyAssumeRoleDuration(50 * time.Minute), // 1h (suite default for identity) - 10m
+ name: "AssumeRole success (shorter identity duration)",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: assumeRoleRequest(2 * time.Hour),
+ advanceClock: 10 * time.Minute,
+ wantHost: "sts.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "sts",
+ wantAuthCredRegion: "us-east-1",
+ wantEventType: &events.AppSessionRequest{},
+ verifySentRequest: verifyAssumeRoleDuration(50 * time.Minute), // 1h (suite default for identity) - 10m
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "AssumeRole success (shorter requested duration)",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: assumeRoleRequest(32 * time.Minute),
- wantHost: "sts.amazonaws.com",
- wantAuthCredKeyID: "FAKEACCESSKEYID",
- wantAuthCredService: "sts",
- wantAuthCredRegion: "us-east-1",
- wantEventType: &events.AppSessionRequest{},
- verifySentRequest: verifyAssumeRoleDuration(32 * time.Minute), // matches the request
+ name: "AssumeRole success (shorter requested duration)",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: assumeRoleRequest(32 * time.Minute),
+ wantHost: "sts.amazonaws.com",
+ wantAuthCredKeyID: "FAKEACCESSKEYID",
+ wantAuthCredService: "sts",
+ wantAuthCredRegion: "us-east-1",
+ wantEventType: &events.AppSessionRequest{},
+ verifySentRequest: verifyAssumeRoleDuration(32 * time.Minute), // matches the request
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
- name: "AssumeRole denied",
- app: consoleApp,
- awsClientSession: session.Must(session.NewSession(&aws.Config{
- Credentials: staticAWSCredentialsForClient,
- Region: aws.String("us-east-1"),
- })),
- request: assumeRoleRequest(2 * time.Hour),
- wantHost: "sts.amazonaws.com",
- advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum
+ name: "AssumeRole denied",
+ app: consoleApp,
+ awsCredentialsProvider: staticAWSCredentialsForClient,
+ awsRegion: "us-east-1",
+ request: assumeRoleRequest(2 * time.Hour),
+ wantHost: "sts.amazonaws.com",
+ advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum
errAssertionFns: []require.ErrorAssertionFunc{
// the request is 403 forbidden by Teleport, so the mock AWS handler won't be sent anything.
hasStatusCode(http.StatusForbidden),
@@ -498,7 +476,7 @@ func TestAWSSignerHandler(t *testing.T) {
// check that the signature is valid.
if !tc.skipVerifySignature {
err := awsutils.VerifyAWSSignature(r,
- credentialsv2.NewStaticCredentialsProvider(tc.wantAuthCredKeyID, "secret", "token"),
+ credentials.NewStaticCredentialsProvider(tc.wantAuthCredKeyID, "secret", "token"),
)
if !assert.NoError(t, err) {
http.Error(w, err.Error(), trace.ErrorToCode(err))
@@ -521,7 +499,7 @@ func TestAWSSignerHandler(t *testing.T) {
suite := createSuite(t, mockAwsHandler, tc.app, fakeClock, awsCfgProvider)
fakeClock.Advance(tc.advanceClock)
- err := tc.request(suite.URL, tc.awsClientSession, tc.wantHost)
+ err := tc.request(t.Context(), suite.URL, tc.awsRegion, tc.awsCredentialsProvider, tc.wantHost)
for _, assertFn := range tc.errAssertionFns {
assertFn(t, err)
}
@@ -625,8 +603,8 @@ func mustNewRequest(t *testing.T, method, url string, body io.Reader) *http.Requ
const assumedRoleKeyID = "assumedRoleKeyID"
var (
- staticAWSCredentialsForAssumedRole = credentials.NewStaticCredentials(assumedRoleKeyID, "assumedRoleKeySecret", "")
- staticAWSCredentialsForClient = credentials.NewStaticCredentials("fakeClientKeyID", "fakeClientSecret", "")
+ staticAWSCredentialsForAssumedRole = credentials.NewStaticCredentialsProvider(assumedRoleKeyID, "assumedRoleKeySecret", "")
+ staticAWSCredentialsForClient = credentials.NewStaticCredentialsProvider("fakeClientKeyID", "fakeClientSecret", "")
)
type suite struct {
diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go
index 67510a8c42c17..5f666738f5b88 100644
--- a/lib/srv/db/common/auth.go
+++ b/lib/srv/db/common/auth.go
@@ -69,11 +69,6 @@ const (
// azureVirtualMachineCacheTTL is the default TTL for Azure virtual machine
// cache entries.
azureVirtualMachineCacheTTL = 5 * time.Minute
-
- // emptyPayloadHash is the SHA-256 for an empty element (as in echo -n | sha256sum).
- // PresignHTTP requires the hash of the body, but when there is no body we hash the empty string.
- // https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
- emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
)
// Auth defines interface for creating auth tokens and TLS configurations.
@@ -1323,7 +1318,7 @@ func (r *awsRedisIAMTokenRequest) toSignedRequestURI(ctx context.Context) (strin
if err != nil {
return "", trace.Wrap(err)
}
- signedURI, _, err := signer.PresignHTTP(ctx, creds, req, emptyPayloadHash, r.serviceName, r.region, r.clock.Now())
+ signedURI, _, err := signer.PresignHTTP(ctx, creds, req, awsutils.EmptyPayloadHash, r.serviceName, r.region, r.clock.Now())
if err != nil {
return "", trace.Wrap(err)
}
diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go
index f62cbb737d5f4..cb479a20631bb 100644
--- a/lib/srv/server/ec2_watcher_test.go
+++ b/lib/srv/server/ec2_watcher_test.go
@@ -22,10 +22,9 @@ import (
"context"
"testing"
- awsv2 "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2"
ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types"
- "github.com/aws/aws-sdk-go/aws"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/require"
@@ -58,16 +57,16 @@ func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.Descri
func instanceMatches(inst ec2types.Instance, filters []ec2types.Filter) bool {
allMatched := true
for _, filter := range filters {
- name := awsv2.ToString(filter.Name)
+ name := aws.ToString(filter.Name)
val := filter.Values[0]
if name == AWSInstanceStateName && inst.State.Name != ec2types.InstanceStateNameRunning {
return false
}
for _, tag := range inst.Tags {
- if awsv2.ToString(tag.Key) != name[4:] {
+ if aws.ToString(tag.Key) != name[4:] {
continue
}
- allMatched = allMatched && awsv2.ToString(tag.Value) != val
+ allMatched = allMatched && aws.ToString(tag.Value) != val
}
}
@@ -91,7 +90,7 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) {
},
expectedFilters: []ec2types.Filter{
{
- Name: awsv2.String(AWSInstanceStateName),
+ Name: aws.String(AWSInstanceStateName),
Values: []string{string(ec2types.InstanceStateNameRunning)},
},
},
@@ -105,11 +104,11 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) {
},
expectedFilters: []ec2types.Filter{
{
- Name: awsv2.String(AWSInstanceStateName),
+ Name: aws.String(AWSInstanceStateName),
Values: []string{string(ec2types.InstanceStateNameRunning)},
},
{
- Name: awsv2.String("tag:hello"),
+ Name: aws.String("tag:hello"),
Values: []string{"other"},
},
},
@@ -156,15 +155,15 @@ func TestEC2Watcher(t *testing.T) {
ctx := context.Background()
present := ec2types.Instance{
- InstanceId: awsv2.String("instance-present"),
+ InstanceId: aws.String("instance-present"),
Tags: []ec2types.Tag{
{
- Key: awsv2.String("teleport"),
- Value: awsv2.String("yes"),
+ Key: aws.String("teleport"),
+ Value: aws.String("yes"),
},
{
- Key: awsv2.String("Name"),
- Value: awsv2.String("Present"),
+ Key: aws.String("Name"),
+ Value: aws.String("Present"),
},
},
State: &ec2types.InstanceState{
@@ -172,20 +171,20 @@ func TestEC2Watcher(t *testing.T) {
},
}
presentOther := ec2types.Instance{
- InstanceId: awsv2.String("instance-present-2"),
+ InstanceId: aws.String("instance-present-2"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("dev"),
+ Key: aws.String("env"),
+ Value: aws.String("dev"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
},
}
presentForEICE := ec2types.Instance{
- InstanceId: awsv2.String("instance-present-3"),
+ InstanceId: aws.String("instance-present-3"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("with-eice"),
- Value: awsv2.String("please"),
+ Key: aws.String("with-eice"),
+ Value: aws.String("please"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
@@ -199,23 +198,23 @@ func TestEC2Watcher(t *testing.T) {
presentOther,
presentForEICE,
{
- InstanceId: awsv2.String("instance-absent"),
+ InstanceId: aws.String("instance-absent"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("prod"),
+ Key: aws.String("env"),
+ Value: aws.String("prod"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNameRunning,
},
},
{
- InstanceId: awsv2.String("instance-absent-3"),
+ InstanceId: aws.String("instance-absent-3"),
Tags: []ec2types.Tag{{
- Key: awsv2.String("env"),
- Value: awsv2.String("prod"),
+ Key: aws.String("env"),
+ Value: aws.String("prod"),
}, {
- Key: awsv2.String("teleport"),
- Value: awsv2.String("yes"),
+ Key: aws.String("teleport"),
+ Value: aws.String("yes"),
}},
State: &ec2types.InstanceState{
Name: ec2types.InstanceStateNamePending,
diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go
index 6b97e237ca861..106ccd8e7f24c 100644
--- a/lib/utils/aws/aws.go
+++ b/lib/utils/aws/aws.go
@@ -19,9 +19,9 @@
package aws
import (
- "bytes"
"context"
"crypto/sha1"
+ "crypto/sha256"
"encoding/hex"
"fmt"
"log/slog"
@@ -33,14 +33,12 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/aws/arn"
- "github.com/aws/aws-sdk-go/aws/credentials"
- v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
+ v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/gravitational/trace"
apievents "github.com/gravitational/teleport/api/types/events"
apiawsutils "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/utils"
- "github.com/gravitational/teleport/lib/utils/aws/migration"
)
const (
@@ -77,6 +75,11 @@ const (
// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html
MaxRoleSessionNameLength = 64
+ // EmptyPayloadHash is the SHA-256 for an empty element (as in echo -n | sha256sum).
+ // PresignHTTP requires the hash of the body, but when there is no body we hash the empty string.
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html
+ EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
+
iamServiceName = "iam"
)
@@ -195,7 +198,8 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider)
return trace.Wrap(err)
}
- reqCopy := req.Clone(context.Background())
+ ctx := context.Background()
+ reqCopy := req.Clone(ctx)
// Remove all the headers that are not present in awsCred.SignedHeaders.
filterHeaders(reqCopy, sigV4.SignedHeaders)
@@ -207,8 +211,13 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider)
return trace.BadParameter("%s", err)
}
- signer := NewSignerV2(credProvider, sigV4.Service)
- _, err = signer.Sign(reqCopy, bytes.NewReader(payload), sigV4.Service, sigV4.Region, t)
+ creds, err := credProvider.Retrieve(ctx)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ signer := NewSigner(sigV4.Service)
+ err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), sigV4.Service, sigV4.Region, t)
if err != nil {
return trace.Wrap(err)
}
@@ -226,23 +235,28 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider)
return nil
}
-// NewSignerV2 is a temporary AWS SDK migration helper.
-func NewSignerV2(provider aws.CredentialsProvider, signingServiceName string) *v4.Signer {
- return NewSigner(migration.NewCredentialsAdapter(provider), signingServiceName)
-}
-
// NewSigner creates a new V4 signer.
-func NewSigner(credentials *credentials.Credentials, signingServiceName string) *v4.Signer {
- options := func(s *v4.Signer) {
+func NewSigner(signingServiceName string) *v4.Signer {
+ return v4.NewSigner(func(opts *v4.SignerOptions) {
// s3 and s3control requests are signed with URL unescaped (found by
// searching "DisableURIPathEscaping" in "aws-sdk-go/service"). Both
// services use "s3" as signing name. See description of
// "DisableURIPathEscaping" for more details.
if signingServiceName == "s3" {
- s.DisableURIPathEscaping = true
+ opts.DisableURIPathEscaping = true
}
+ })
+}
+
+// GetV4PayloadHash returns the V4 signing payload hash.
+func GetV4PayloadHash(payload []byte) string {
+ if len(payload) == 0 {
+ return EmptyPayloadHash
}
- return v4.NewSigner(credentials, options)
+
+ hash := sha256.New()
+ hash.Write(payload)
+ return hex.EncodeToString(hash.Sum(nil))
}
// filterHeaders removes request headers that are not in the headers list and returns the removed header keys.
diff --git a/lib/utils/aws/credentials.go b/lib/utils/aws/credentials.go
deleted file mode 100644
index 8be99d898e5f7..0000000000000
--- a/lib/utils/aws/credentials.go
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package aws
-
-import (
- "context"
-
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/aws/endpoints"
- "github.com/aws/aws-sdk-go/aws/session"
- "github.com/gravitational/trace"
-
- "github.com/gravitational/teleport/lib/modules"
-)
-
-// AWSSessionProvider defines a function that creates an AWS Session.
-// It must use ambient credentials if Integration is empty.
-// It must use Integration credentials otherwise.
-type AWSSessionProvider func(ctx context.Context, region string, integration string) (*session.Session, error)
-
-// StaticAWSSessionProvider is a helper method that returns a static session.
-// Must not be used to provide sessions when using Integrations.
-func StaticAWSSessionProvider(awsSession *session.Session) AWSSessionProvider {
- return func(ctx context.Context, region, integration string) (*session.Session, error) {
- if integration != "" {
- return nil, trace.BadParameter("integration %q is not allowed to use static sessions", integration)
- }
- return awsSession, nil
- }
-}
-
-// SessionProviderUsingAmbientCredentials returns an AWS Session using ambient credentials.
-// This is in contrast with AWS Sessions that can be generated using an AWS OIDC Integration.
-func SessionProviderUsingAmbientCredentials() AWSSessionProvider {
- return func(ctx context.Context, region, integration string) (*session.Session, error) {
- if integration != "" {
- return nil, trace.BadParameter("integration %q is not allowed to use ambient sessions", integration)
- }
- useFIPSEndpoint := endpoints.FIPSEndpointStateUnset
- if modules.GetModules().IsBoringBinary() {
- useFIPSEndpoint = endpoints.FIPSEndpointStateEnabled
- }
- session, err := session.NewSessionWithOptions(session.Options{
- SharedConfigState: session.SharedConfigEnable,
- Config: aws.Config{
- UseFIPSEndpoint: useFIPSEndpoint,
- },
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- return session, nil
- }
-}
diff --git a/lib/utils/aws/s3.go b/lib/utils/aws/s3.go
index cd050dff34ce9..c6fcb0275c736 100644
--- a/lib/utils/aws/s3.go
+++ b/lib/utils/aws/s3.go
@@ -22,15 +22,12 @@ import (
"context"
"errors"
"io"
- "net/http"
"strings"
awsv2 "github.com/aws/aws-sdk-go-v2/aws"
managerv2 "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
s3v2 "github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
- "github.com/aws/aws-sdk-go/aws/awserr"
- "github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/smithy-go"
"github.com/gravitational/trace"
)
@@ -42,25 +39,6 @@ func ConvertS3Error(err error) error {
return nil
}
- // SDK v1 errors:
- var rerr awserr.RequestFailure
- if errors.As(err, &rerr) && rerr.StatusCode() == http.StatusForbidden {
- return trace.AccessDenied("%s", rerr.Message())
- }
-
- var aerr awserr.Error
- if errors.As(err, &aerr) {
- switch aerr.Code() {
- case s3.ErrCodeNoSuchKey, s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchUpload, "NotFound":
- return trace.NotFound("%s", aerr)
- case s3.ErrCodeBucketAlreadyExists, s3.ErrCodeBucketAlreadyOwnedByYou:
- return trace.AlreadyExists("%s", aerr)
- default:
- return trace.BadParameter("%s", aerr)
- }
- }
-
- // SDK v2 errors:
var noSuchKey *s3types.NoSuchKey
if errors.As(err, &noSuchKey) {
return trace.NotFound("%s", noSuchKey)
diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go
index 31d29532c20d8..385ccfdcbae79 100644
--- a/lib/utils/aws/signing.go
+++ b/lib/utils/aws/signing.go
@@ -19,10 +19,10 @@
package aws
import (
- "bytes"
"context"
"io"
"net/http"
+ "time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/gravitational/trace"
@@ -94,8 +94,14 @@ func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*
// 100-continue" headers without being signed, otherwise the Athena service
// would reject the requests.
unsignedHeaders := removeUnsignedHeaders(reqCopy)
- signer := NewSignerV2(signCtx.Credentials, signCtx.SigningName)
- _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, signCtx.Clock.Now())
+
+ creds, err := signCtx.Credentials.Retrieve(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ signer := NewSigner(signCtx.SigningName)
+ err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), signCtx.SigningName, signCtx.SigningRegion, time.Now())
if err != nil {
return nil, trace.Wrap(err)
}
diff --git a/lib/utils/aws/xml.go b/lib/utils/aws/xml.go
index e88ca680a5934..69e9c5b84c780 100644
--- a/lib/utils/aws/xml.go
+++ b/lib/utils/aws/xml.go
@@ -22,62 +22,52 @@ import (
"bytes"
"encoding/xml"
- "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil"
+ smithyxml "github.com/aws/smithy-go/encoding/xml"
+
"github.com/gravitational/trace"
)
// IsXMLOfLocalName returns true if the root XML has the provided (local) name.
func IsXMLOfLocalName(data []byte, wantLocalName string) bool {
- var name xml.Name
- if err := xml.Unmarshal(data, &name); err == nil {
- return wantLocalName == name.Local
+ st, err := smithyxml.FetchRootElement(xml.NewDecoder(bytes.NewReader(data)))
+ if err == nil && st.Name.Local == wantLocalName {
+ return true
}
+
return false
}
// UnmarshalXMLChildNode decodes the XML-encoded data and stores the child node
// with the specified name to v, where v is a pointer to an AWS SDK v1 struct.
func UnmarshalXMLChildNode(v interface{}, data []byte, childName string) error {
- return trace.Wrap(xmlutil.UnmarshalXML(v, xml.NewDecoder(bytes.NewReader(data)), childName))
+ decoder := xml.NewDecoder(bytes.NewReader(data))
+ st, err := smithyxml.FetchRootElement(decoder)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ nodeDecoder := smithyxml.WrapNodeDecoder(decoder, st)
+ childElem, err := nodeDecoder.GetElement(childName)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ return trace.Wrap(decoder.DecodeElement(v, &childElem))
}
// MarshalXML marshals the provided root name and a map of children in XML with
// default indent (prefix "", indent " ").
-func MarshalXML(rootName xml.Name, children map[string]any) ([]byte, error) {
+func MarshalXML(root string, namespace string, v any) ([]byte, error) {
var buf bytes.Buffer
encoder := xml.NewEncoder(&buf)
encoder.Indent("", " ")
-
- err := encodeXMLNode(encoder, rootName, func() error {
- for childName, childValue := range children {
- if err := encodeXMLNodeAWSSDKV1(encoder, childName, childValue); err != nil {
- return trace.Wrap(err)
- }
- }
- return nil
+ err := encoder.EncodeElement(v, xml.StartElement{
+ Name: xml.Name{Local: root},
+ Attr: []xml.Attr{
+ {Name: xml.Name{Local: "xmlns"}, Value: namespace},
+ },
})
if err != nil {
return nil, trace.Wrap(err)
}
- if err := trace.Wrap(encoder.Flush()); err != nil {
- return nil, trace.Wrap(err)
- }
return buf.Bytes(), nil
}
-
-func encodeXMLNode(encoder *xml.Encoder, name xml.Name, encodeChildren func() error) error {
- startElement := xml.StartElement{Name: name}
- if err := encoder.EncodeToken(startElement); err != nil {
- return trace.Wrap(err)
- }
- if err := encodeChildren(); err != nil {
- return trace.Wrap(err)
- }
- return trace.Wrap(encoder.EncodeToken(startElement.End()))
-}
-
-func encodeXMLNodeAWSSDKV1(encoder *xml.Encoder, name string, v any) error {
- return encodeXMLNode(encoder, xml.Name{Local: name}, func() error {
- return trace.Wrap(xmlutil.BuildXML(v, encoder))
- })
-}
diff --git a/lib/utils/aws/xml_test.go b/lib/utils/aws/xml_test.go
index f2553e6a4cb9c..b4bb39d6f14ad 100644
--- a/lib/utils/aws/xml_test.go
+++ b/lib/utils/aws/xml_test.go
@@ -19,15 +19,13 @@
package aws
import (
- "encoding/xml"
- "net/http"
"strings"
"testing"
"time"
- "github.com/aws/aws-sdk-go/aws"
- "github.com/aws/aws-sdk-go/private/protocol"
- "github.com/aws/aws-sdk-go/service/sts"
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/sts"
+ ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types"
"github.com/stretchr/testify/require"
)
@@ -40,10 +38,10 @@ func TestIsXMLOfLocalName(t *testing.T) {
func TestUnmarshalXMLChildNode(t *testing.T) {
want := sts.AssumeRoleOutput{
- AssumedRoleUser: &sts.AssumedRoleUser{
+ AssumedRoleUser: &ststypes.AssumedRoleUser{
Arn: aws.String("some-arn"),
},
- Credentials: &sts.Credentials{
+ Credentials: &ststypes.Credentials{
AccessKeyId: aws.String("some-access-key-id"),
SecretAccessKey: aws.String("some-secret-access-key"),
SessionToken: aws.String("some-session-token"),
@@ -75,49 +73,22 @@ func TestUnmarshalXMLChildNode(t *testing.T) {
}
func TestMarshalXMLIndent(t *testing.T) {
- data, err := MarshalXML(
- xml.Name{
- Local: "AssumeRoleResponse",
- Space: "https://sts.amazonaws.com/doc/2011-06-15/",
- },
- map[string]any{
- "AssumeRoleResult": sts.AssumeRoleOutput{
- AssumedRoleUser: &sts.AssumedRoleUser{
- Arn: aws.String("some-arn"),
- },
- Credentials: &sts.Credentials{
- AccessKeyId: aws.String("some-access-key-id"),
- SecretAccessKey: aws.String("some-secret-access-key"),
- SessionToken: aws.String("some-session-token"),
- Expiration: aws.Time(time.Unix(1234567890, 0).UTC()),
- },
- },
- "ResponseMetadata": protocol.ResponseMetadata{
- RequestID: "some-request-id",
- StatusCode: http.StatusOK,
- },
- },
- )
+ simpleAssumeRole := struct {
+ Test string
+ Encoding time.Time
+ }{
+ Test: "test",
+ Encoding: time.Unix(1234567890, 0).UTC(),
+ }
+
+ data, err := MarshalXML("AssumeRoleResponse", "https://sts.amazonaws.com/doc/2011-06-15/", simpleAssumeRole)
require.NoError(t, err)
// Nodes are not sorted. Use ElementsMatch to ensure each line is present.
require.ElementsMatch(t, []string{
``,
- ` `,
- ` `,
- ` some-secret-access-key`,
- ` some-session-token`,
- ` some-access-key-id`,
- ` 2009-02-13T23:31:30Z`,
- ` `,
- ` `,
- ` some-arn`,
- ` `,
- ` `,
- ` `,
- ` 200`,
- ` some-request-id`,
- ` `,
+ ` test`,
+ ` 2009-02-13T23:31:30Z`,
``,
}, strings.Split(string(data), "\n"))
}