From c91d89a3524207f9f33d9b83e4e878fb662d55b8 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Tue, 11 Mar 2025 21:23:08 -0300 Subject: [PATCH 1/4] refactor: remove references to aws sdk v1 --- go.mod | 1 + go.sum | 2 + lib/events/fips.go | 12 +- .../externalauditstorage/configurator_test.go | 2 +- lib/srv/alpnproxy/aws_local_proxy_test.go | 104 +++-- lib/srv/app/aws/endpoints_test.go | 12 +- lib/srv/app/aws/handler.go | 7 +- lib/srv/app/aws/handler_test.go | 416 +++++++++--------- lib/srv/db/common/auth.go | 7 +- lib/srv/server/ec2_watcher_test.go | 53 ++- lib/utils/aws/aws.go | 44 +- lib/utils/aws/credentials.go | 71 --- lib/utils/aws/signing.go | 12 +- lib/utils/aws/xml.go | 58 +-- lib/utils/aws/xml_test.go | 61 +-- 15 files changed, 382 insertions(+), 480 deletions(-) delete mode 100644 lib/utils/aws/credentials.go diff --git a/go.mod b/go.mod index 177fceb4fcf1d..50716677616e2 100644 --- a/go.mod +++ b/go.mod @@ -293,6 +293,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.10.15 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.15 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 // indirect + github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.25.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.29.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect diff --git a/go.sum b/go.sum index 9c04612ef6dec..29e702d79fba5 100644 --- a/go.sum +++ b/go.sum @@ -926,6 +926,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15 h1:moLQUoVq91Liq github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.15/go.mod h1:ZH34PJUc8ApjBIfgQCFvkWcUDBtl/WTD+uiYHjd8igA= github.com/aws/aws-sdk-go-v2/service/kms v1.38.0 h1:+2/0Cq0R/audJhwM1GpJMg8X1TTrMKDFRLO5RMaNRU0= github.com/aws/aws-sdk-go-v2/service/kms v1.38.0/go.mod h1:cQn6tAF77Di6m4huxovNM7NVAozWTZLsDRp9t8Z/WYk= +github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1 h1:EabaKQAptxXAeSL0sXKqfupPe/CpH965wqoloUK0aMM= +github.com/aws/aws-sdk-go-v2/service/lambda v1.70.1/go.mod h1:c27kk10S36lBYgbG1jR3opn4OAS5Y/4wjJa1GiHK/X4= github.com/aws/aws-sdk-go-v2/service/memorydb v1.26.0 h1:nO9RCZnfAIF5q43IDLWtf7vu/l16RKzeTkv5GObkyME= github.com/aws/aws-sdk-go-v2/service/memorydb v1.26.0/go.mod h1:pfuDC5zBwunXdE44WT1PRbtzuXWGohKFcFLtv+ezI6k= github.com/aws/aws-sdk-go-v2/service/opensearch v1.46.0 h1:eR65kYpNlKpGkkvg+A83hc0hpk2CHappaz1JAUCcxVs= diff --git a/lib/events/fips.go b/lib/events/fips.go index b00a1910df111..f0feb6d58958e 100644 --- a/lib/events/fips.go +++ b/lib/events/fips.go @@ -19,7 +19,7 @@ package events import ( - "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/gravitational/teleport/api/types" ) @@ -31,14 +31,14 @@ const ( ) var ( - fipsToAWS = map[types.ClusterAuditConfigSpecV2_FIPSEndpointState]endpoints.FIPSEndpointState{ - types.ClusterAuditConfigSpecV2_FIPS_UNSET: endpoints.FIPSEndpointStateUnset, - types.ClusterAuditConfigSpecV2_FIPS_ENABLED: endpoints.FIPSEndpointStateEnabled, - types.ClusterAuditConfigSpecV2_FIPS_DISABLED: endpoints.FIPSEndpointStateDisabled, + fipsToAWS = map[types.ClusterAuditConfigSpecV2_FIPSEndpointState]aws.FIPSEndpointState{ + types.ClusterAuditConfigSpecV2_FIPS_UNSET: aws.FIPSEndpointStateUnset, + types.ClusterAuditConfigSpecV2_FIPS_ENABLED: aws.FIPSEndpointStateEnabled, + types.ClusterAuditConfigSpecV2_FIPS_DISABLED: aws.FIPSEndpointStateDisabled, } ) // FIPSProtoStateToAWSState converts a FIPS proto state to an aws endpoints.FIPSEndpointState -func FIPSProtoStateToAWSState(state types.ClusterAuditConfigSpecV2_FIPSEndpointState) endpoints.FIPSEndpointState { +func FIPSProtoStateToAWSState(state types.ClusterAuditConfigSpecV2_FIPSEndpointState) aws.FIPSEndpointState { return fipsToAWS[state] } diff --git a/lib/integrations/externalauditstorage/configurator_test.go b/lib/integrations/externalauditstorage/configurator_test.go index 631fae9337aed..df845e15a1e15 100644 --- a/lib/integrations/externalauditstorage/configurator_test.go +++ b/lib/integrations/externalauditstorage/configurator_test.go @@ -25,9 +25,9 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sts" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" - "github.com/aws/aws-sdk-go/aws" "github.com/google/uuid" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 1e01cfed5606d..75591e1062f98 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -19,8 +19,6 @@ package alpnproxy import ( - "context" - "encoding/xml" "net/http" "net/http/httptest" "testing" @@ -28,7 +26,6 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go-v2/service/sts" ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/stretchr/testify/require" @@ -39,8 +36,8 @@ func TestAWSAccessMiddleware(t *testing.T) { t.Parallel() assumedRoleARN := "arn:aws:sts::123456789012:assumed-role/role-name/role-session-name" - localProxyCred := credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", "") - assumedRoleCred := credentials.NewStaticCredentialsProvider("assumed-role", "assumed-role-secret", "assumed-role-token") + localCred := aws.Credentials{AccessKeyID: "local-proxy", SecretAccessKey: "local-proxy-secret"} + assumedRoleCred := aws.Credentials{AccessKeyID: "assumed-role", SecretAccessKey: "assumed-role-secret", SessionToken: "assumed-role-token"} m := &AWSAccessMiddleware{ AWSCredentialsProvider: credentials.NewStaticCredentialsProvider("local-proxy", "local-proxy-secret", ""), @@ -48,11 +45,10 @@ func TestAWSAccessMiddleware(t *testing.T) { require.NoError(t, m.CheckAndSetDefaults()) stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - - awsutils.NewSignerV2(localProxyCred, "sts").Sign(stsRequestByLocalProxyCred, nil, "sts", "us-west-1", time.Now()) + awsutils.NewSignerV2("sts").SignHTTP(t.Context(), localCred, stsRequestByLocalProxyCred, awsutils.EmptyPayloadHash, "sts", "us-west-1", time.Now()) requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - awsutils.NewSignerV2(assumedRoleCred, "s3").Sign(requestByAssumedRole, nil, "s3", "us-west-1", time.Now()) + awsutils.NewSignerV2("s3").SignHTTP(t.Context(), assumedRoleCred, requestByAssumedRole, awsutils.EmptyPayloadHash, "s3", "us-west-1", time.Now()) t.Run("request no authorization", func(t *testing.T) { recorder := httptest.NewRecorder() @@ -99,34 +95,60 @@ func TestAWSAccessMiddleware(t *testing.T) { }) } -func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsProvider) *http.Response { - t.Helper() +// IdentityResult represents the identitiy result of an AWS response. +type IdentityResult struct { + ARN string `xml:"Arn"` +} - credValue, err := provider.Retrieve(context.Background()) - require.NoError(t, err) +// ResponseMetadata contains the metadata of a AWS response. +type ResponseMetadata struct { + RequestID string `xml:"RequestID"` + StatusCode int `xml:"StatusCode"` +} - body, err := awsutils.MarshalXML( - xml.Name{ - Local: "AssumeRoleResponse", - Space: "https://sts.amazonaws.com/doc/2011-06-15/", - }, - map[string]any{ - "AssumeRoleResult": sts.AssumeRoleOutput{ - AssumedRoleUser: &ststypes.AssumedRoleUser{ - Arn: aws.String(roleARN), - }, - Credentials: &ststypes.Credentials{ - AccessKeyId: aws.String(credValue.AccessKeyID), - SecretAccessKey: aws.String(credValue.SecretAccessKey), - SessionToken: aws.String(credValue.SessionToken), - }, +// AssumeRoleResult contains the assume role result. +type AssumeRoleResult struct { + // AssumedRoleUser is the assumed user. + AssumedRoleUser IdentityResult `xml:"AssumedRoleUser"` + // Credentials is the generated credentials. + Credentials ststypes.Credentials `xml:"Credentials"` +} + +// AssumeRoleResponse is the response of assume role. +type AssumeRoleResponse struct { + // AssumeRoleResult is the resulting response from assume role. + AssumeRoleResult AssumeRoleResult `xml:"AssumeRoleResult"` + // Response is the response metadata. + Response ResponseMetadata `xml:"ResponseMetadata"` +} + +// GetCallerIdentityResponse is the response of get caller identity call. +type GetCallerIdentityResponse struct { + // AssumeRoleResult is the resulting response from assume role. + GetCallerIdentityResult IdentityResult `xml:"GetCallerIdentityResult"` + // Response is the response metadata. + Response ResponseMetadata `xml:"ResponseMetadata"` +} + +func assumeRoleResponse(t *testing.T, roleARN string, creds aws.Credentials) *http.Response { + t.Helper() + + body, err := awsutils.MarshalXML("AssumeRoleResponse", "https://sts.amazonaws.com/doc/2011-06-15/", AssumeRoleResponse{ + AssumeRoleResult: AssumeRoleResult{ + AssumedRoleUser: IdentityResult{ + ARN: roleARN, }, - "ResponseMetadata": map[string]any{ - "StatusCode": http.StatusOK, - "RequestID": "22222222-3333-3333-3333-333333333333", + Credentials: ststypes.Credentials{ + AccessKeyId: aws.String(creds.AccessKeyID), + SecretAccessKey: aws.String(creds.SecretAccessKey), + SessionToken: aws.String(creds.SessionToken), }, }, - ) + Response: ResponseMetadata{ + StatusCode: http.StatusOK, + RequestID: "22222222-3333-3333-3333-333333333333", + }, + }) require.NoError(t, err) return fakeHTTPResponse(http.StatusOK, body) } @@ -134,21 +156,15 @@ func assumeRoleResponse(t *testing.T, roleARN string, provider aws.CredentialsPr func getCallerIdentityResponse(t *testing.T, roleARN string) *http.Response { t.Helper() - body, err := awsutils.MarshalXML( - xml.Name{ - Local: "GetCallerIdentityResponse", - Space: "https://sts.amazonaws.com/doc/2011-06-15/", + body, err := awsutils.MarshalXML("GetCallerIdentityResponse", "https://sts.amazonaws.com/doc/2011-06-15/", GetCallerIdentityResponse{ + GetCallerIdentityResult: IdentityResult{ + ARN: roleARN, }, - map[string]any{ - "GetCallerIdentityResult": sts.GetCallerIdentityOutput{ - Arn: aws.String(roleARN), - }, - "ResponseMetadata": map[string]any{ - "StatusCode": http.StatusOK, - "RequestID": "22222222-3333-3333-3333-333333333333", - }, + Response: ResponseMetadata{ + StatusCode: http.StatusOK, + RequestID: "22222222-3333-3333-3333-333333333333", }, - ) + }) require.NoError(t, err) return fakeHTTPResponse(http.StatusOK, body) } diff --git a/lib/srv/app/aws/endpoints_test.go b/lib/srv/app/aws/endpoints_test.go index 2690482b71ad1..7a1116be35430 100644 --- a/lib/srv/app/aws/endpoints_test.go +++ b/lib/srv/app/aws/endpoints_test.go @@ -19,20 +19,20 @@ package aws import ( - "bytes" "net/http" "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/stretchr/testify/require" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) func TestResolveEndpoints(t *testing.T) { - signer := v4.NewSigner(credentials.NewStaticCredentials("fakeClientKeyID", "fakeClientSecret", "")) + creds := aws.Credentials{AccessKeyID: "fakeClientKeyID", SecretAccessKey: "fakeClientSecret"} + signer := v4.NewSigner() region := "us-east-1" now := time.Now() @@ -40,7 +40,7 @@ func TestResolveEndpoints(t *testing.T) { req, err := http.NewRequest("GET", "http://localhost", nil) require.NoError(t, err) - _, err = signer.Sign(req, bytes.NewReader(nil), "ecr", "us-east-1", now) + err = signer.SignHTTP(t.Context(), creds, req, awsutils.EmptyPayloadHash, "ecr", "us-east-1", now) require.NoError(t, err) _, err = resolveEndpoint(req, awsutils.AuthorizationHeader) @@ -52,7 +52,7 @@ func TestResolveEndpoints(t *testing.T) { require.NoError(t, err) req.Header.Set("X-Forwarded-Host", "some-service.us-east-1.amazonaws.com") - _, err = signer.Sign(req, bytes.NewReader(nil), "some-service", region, now) + err = signer.SignHTTP(t.Context(), creds, req, awsutils.EmptyPayloadHash, "some-service", region, now) require.NoError(t, err) endpoint, err := resolveEndpoint(req, awsutils.AuthorizationHeader) diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 2ad691537fa4a..476ad04e42c1c 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -165,6 +165,11 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt return trace.Wrap(err) } + reqCloneForAudit, err := cloneRequest(unsignedReq) + if err != nil { + return trace.Wrap(err) + } + awsCfg, err := s.AWSConfigProvider.GetConfig(s.closeContext, re.SigningRegion, awsconfig.WithDetailedAssumeRole(awsconfig.AssumeRole{ RoleARN: sessCtx.Identity.RouteToApp.AWSRoleARN, @@ -189,7 +194,7 @@ func (s *signerHandler) serveCommonRequest(sessCtx *common.SessionContext, w htt } recorder := httplib.NewResponseStatusRecorder(w) s.fwd.ServeHTTP(recorder, signedReq) - s.emitAudit(sessCtx, unsignedReq, uint32(recorder.Status()), re) + s.emitAudit(sessCtx, reqCloneForAudit, uint32(recorder.Status()), re) return nil } diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 182de8945b7f8..e51210582db35 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -31,16 +31,13 @@ import ( "testing" "time" - credentialsv2 "github.com/aws/aws-sdk-go-v2/credentials" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/dynamodb" - "github.com/aws/aws-sdk-go/service/lambda" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + transporthttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/dynamodb" + "github.com/aws/aws-sdk-go-v2/service/lambda" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -66,47 +63,51 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -type makeRequest func(url string, provider client.ConfigProvider, awsHost string) error +type makeRequest func(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error -func s3Request(url string, provider client.ConfigProvider, awsHost string) error { - return s3RequestWithTransport(url, provider, &requestByHTTPSProxy{xForwardedHost: awsHost}) +func s3Request(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { + return s3RequestWithTransport(ctx, url, region, provider, &requestByHTTPSProxy{xForwardedHost: awsHost}) } -func s3RequestByAssumedRole(url string, provider client.ConfigProvider, awsHost string) error { - return s3RequestWithTransport(url, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost}) +func s3RequestByAssumedRole(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { + return s3RequestWithTransport(ctx, url, region, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost}) } -func s3RequestWithTransport(url string, provider client.ConfigProvider, transport http.RoundTripper) error { - s3Client := s3.New(provider, &aws.Config{ - Endpoint: &url, - MaxRetries: aws.Int(0), +func s3RequestWithTransport(ctx context.Context, url string, region string, provider aws.CredentialsProvider, transport http.RoundTripper) error { + s3Client := s3.New(s3.Options{ + Credentials: provider, + BaseEndpoint: &url, + Region: region, + RetryMaxAttempts: 0, HTTPClient: &http.Client{ Transport: transport, Timeout: 5 * time.Second, }, }) - _, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) + _, err := s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) return err } -func dynamoRequest(url string, provider client.ConfigProvider, awsHost string) error { - return dynamoRequestWithTransport(url, provider, &requestByHTTPSProxy{xForwardedHost: awsHost}) +func dynamoRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { + return dynamoRequestWithTransport(ctx, url, region, provider, &requestByHTTPSProxy{xForwardedHost: awsHost}) } -func dynamoRequestByAssumedRole(url string, provider client.ConfigProvider, awsHost string) error { - return dynamoRequestWithTransport(url, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost}) +func dynamoRequestByAssumedRole(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { + return dynamoRequestWithTransport(ctx, url, region, provider, &requestByAssumedRoleTransport{xForwardedHost: awsHost}) } -func dynamoRequestWithTransport(url string, provider client.ConfigProvider, transport http.RoundTripper) error { - dynamoClient := dynamodb.New(provider, &aws.Config{ - Endpoint: &url, - MaxRetries: aws.Int(0), +func dynamoRequestWithTransport(ctx context.Context, url string, region string, provider aws.CredentialsProvider, transport http.RoundTripper) error { + dynamoClient := dynamodb.New(dynamodb.Options{ + Credentials: provider, + BaseEndpoint: &url, + Region: region, + RetryMaxAttempts: 0, HTTPClient: &http.Client{ Transport: transport, Timeout: 5 * time.Second, }, }) - _, err := dynamoClient.Scan(&dynamodb.ScanInput{ + _, err := dynamoClient.Scan(ctx, &dynamodb.ScanInput{ TableName: aws.String("test-table"), }) return err @@ -116,30 +117,32 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran // size. Use a 1MB limit instead of the actual 70MB limit. const maxTestHTTPRequestBodySize = 1 << 20 -func maxSizeExceededRequest(url string, provider client.ConfigProvider, awsHost string) error { +func maxSizeExceededRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { // fake an upload that's too large payload := strings.Repeat("x", maxTestHTTPRequestBodySize) - return lambdaRequestWithPayload(url, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost}) + return lambdaRequestWithPayload(ctx, url, region, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost}) } -func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error { +func lambdaRequest(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { // fake a zip file with 70% of the max limit. Lambda will base64 encode it, // which bloats it up, and our proxy should still handle it. const size = (maxTestHTTPRequestBodySize * 7) / 10 payload := strings.Repeat("x", size) - return lambdaRequestWithPayload(url, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost}) + return lambdaRequestWithPayload(ctx, url, region, provider, payload, &requestByHTTPSProxy{xForwardedHost: awsHost}) } -func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string, transport http.RoundTripper) error { - lambdaClient := lambda.New(provider, &aws.Config{ - Endpoint: &url, - MaxRetries: aws.Int(0), +func lambdaRequestWithPayload(ctx context.Context, url string, region string, provider aws.CredentialsProvider, payload string, transport http.RoundTripper) error { + lambdaClient := lambda.New(lambda.Options{ + Credentials: provider, + BaseEndpoint: &url, + Region: region, + RetryMaxAttempts: 0, HTTPClient: &http.Client{ Timeout: 5 * time.Second, Transport: transport, }, }) - _, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{ + _, err := lambdaClient.UpdateFunctionCode(ctx, &lambda.UpdateFunctionCodeInput{ FunctionName: aws.String("fakeFunc"), ZipFile: []byte(payload), }) @@ -147,17 +150,19 @@ func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payloa } func assumeRoleRequest(requestDuration time.Duration) makeRequest { - return func(url string, provider client.ConfigProvider, awsHost string) error { - stsClient := stsutils.NewV1(provider, &aws.Config{ - Endpoint: &url, - MaxRetries: aws.Int(0), + return func(ctx context.Context, url string, region string, provider aws.CredentialsProvider, awsHost string) error { + stsClient := stsutils.NewFromConfig(aws.Config{ + Credentials: provider, + BaseEndpoint: &url, + Region: region, + RetryMaxAttempts: 0, HTTPClient: &http.Client{ Timeout: 5 * time.Second, Transport: &requestByHTTPSProxy{xForwardedHost: awsHost}, }, }) - _, err := stsClient.AssumeRole(&sts.AssumeRoleInput{ - DurationSeconds: aws.Int64(int64(requestDuration.Seconds())), + _, err := stsClient.AssumeRole(ctx, &sts.AssumeRoleInput{ + DurationSeconds: aws.Int32(int32(requestDuration.Seconds())), RoleSessionName: aws.String("test-session"), RoleArn: aws.String("arn:aws:iam::123456789012:role/test-role"), }) @@ -191,9 +196,9 @@ func (r requestByAssumedRoleTransport) RoundTrip(req *http.Request) (*http.Respo func hasStatusCode(wantStatusCode int) require.ErrorAssertionFunc { return func(t require.TestingT, err error, msgAndArgs ...interface{}) { - var apiErr awserr.RequestFailure - require.ErrorAs(t, err, &apiErr, msgAndArgs...) - require.Equal(t, wantStatusCode, apiErr.StatusCode(), msgAndArgs...) + var respErr *transporthttp.ResponseError + require.ErrorAs(t, err, &respErr, msgAndArgs...) + require.Equal(t, wantStatusCode, respErr.Response.StatusCode, msgAndArgs...) } } @@ -224,47 +229,44 @@ func TestAWSSignerHandler(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - app types.Application - awsClientSession *session.Session - awsConfigProvider awsconfig.Provider - request makeRequest - advanceClock time.Duration - wantHost string - wantAuthCredService string - wantAuthCredRegion string - wantAuthCredKeyID string - wantEventType events.AuditEvent - wantAssumedRole string - skipVerifySignature bool - verifySentRequest func(*testing.T, *http.Request) - errAssertionFns []require.ErrorAssertionFunc + name string + app types.Application + awsCredentialsProvider aws.CredentialsProvider + awsRegion string + awsConfigProvider awsconfig.Provider + request makeRequest + advanceClock time.Duration + wantHost string + wantAuthCredService string + wantAuthCredRegion string + wantAuthCredKeyID string + wantEventType events.AuditEvent + wantAssumedRole string + skipVerifySignature bool + verifySentRequest func(*testing.T, *http.Request) + errAssertionFns []require.ErrorAssertionFunc }{ { - name: "s3 access", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-west-2"), - })), - request: s3Request, - wantHost: "s3.us-west-2.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "s3", - wantAuthCredRegion: "us-west-2", - wantEventType: &events.AppSessionRequest{}, + name: "s3 access", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-west-2", + request: s3Request, + wantHost: "s3.us-west-2.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "s3", + wantAuthCredRegion: "us-west-2", + wantEventType: &events.AppSessionRequest{}, errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "s3 access with integration", - app: consoleAppWithIntegration, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-west-2"), - })), - request: s3Request, + name: "s3 access with integration", + app: consoleAppWithIntegration, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-west-2", + request: s3Request, awsConfigProvider: &mocks.AWSConfigProvider{ OIDCIntegrationClient: &mocks.FakeOIDCIntegrationClient{ Integration: awsOIDCIntegration, @@ -281,144 +283,126 @@ func TestAWSSignerHandler(t *testing.T) { }, }, { - name: "s3 access with different region", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-west-1"), - })), - request: s3Request, - wantHost: "s3.us-west-1.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "s3", - wantAuthCredRegion: "us-west-1", - wantEventType: &events.AppSessionRequest{}, + name: "s3 access with different region", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-west-1", + request: s3Request, + wantHost: "s3.us-west-1.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "s3", + wantAuthCredRegion: "us-west-1", + wantEventType: &events.AppSessionRequest{}, errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "s3 access missing credentials", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.AnonymousCredentials, - Region: aws.String("us-west-1"), - })), - request: s3Request, + name: "s3 access missing credentials", + app: consoleApp, + awsCredentialsProvider: aws.AnonymousCredentials{}, + awsRegion: "us-west-1", + request: s3Request, errAssertionFns: []require.ErrorAssertionFunc{ hasStatusCode(http.StatusBadRequest), }, }, { - name: "s3 access by assumed role", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForAssumedRole, - Region: aws.String("us-west-2"), - })), - request: s3RequestByAssumedRole, - wantHost: "s3.us-west-2.amazonaws.com", - wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID - wantAuthCredService: "s3", - wantAuthCredRegion: "us-west-2", - wantEventType: &events.AppSessionRequest{}, - wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit - skipVerifySignature: true, // not re-signing + name: "s3 access by assumed role", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForAssumedRole, + awsRegion: "us-west-2", + request: s3RequestByAssumedRole, + wantHost: "s3.us-west-2.amazonaws.com", + wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID + wantAuthCredService: "s3", + wantAuthCredRegion: "us-west-2", + wantEventType: &events.AppSessionRequest{}, + wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit + skipVerifySignature: true, // not re-signing errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "DynamoDB access", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: dynamoRequest, - wantHost: "dynamodb.us-east-1.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "dynamodb", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionDynamoDBRequest{}, + name: "DynamoDB access", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: dynamoRequest, + wantHost: "dynamodb.us-east-1.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "dynamodb", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionDynamoDBRequest{}, errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "DynamoDB access with different region", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-west-1"), - })), - request: dynamoRequest, - wantHost: "dynamodb.us-west-1.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "dynamodb", - wantAuthCredRegion: "us-west-1", - wantEventType: &events.AppSessionDynamoDBRequest{}, + name: "DynamoDB access with different region", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-west-1", + request: dynamoRequest, + wantHost: "dynamodb.us-west-1.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "dynamodb", + wantAuthCredRegion: "us-west-1", + wantEventType: &events.AppSessionDynamoDBRequest{}, errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "DynamoDB access missing credentials", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: credentials.AnonymousCredentials, - Region: aws.String("us-west-1"), - })), - request: dynamoRequest, + name: "DynamoDB access missing credentials", + app: consoleApp, + awsCredentialsProvider: aws.AnonymousCredentials{}, + awsRegion: "us-west-1", + request: dynamoRequest, errAssertionFns: []require.ErrorAssertionFunc{ hasStatusCode(http.StatusBadRequest), }, }, { - name: "DynamoDB access by assumed role", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForAssumedRole, - Region: aws.String("us-east-1"), - })), - request: dynamoRequestByAssumedRole, - wantHost: "dynamodb.us-east-1.amazonaws.com", - wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID - wantAuthCredService: "dynamodb", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionDynamoDBRequest{}, - wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit - skipVerifySignature: true, // not re-signing + name: "DynamoDB access by assumed role", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForAssumedRole, + awsRegion: "us-east-1", + request: dynamoRequestByAssumedRole, + wantHost: "dynamodb.us-east-1.amazonaws.com", + wantAuthCredKeyID: assumedRoleKeyID, // not using service's access key ID + wantAuthCredService: "dynamodb", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionDynamoDBRequest{}, + wantAssumedRole: fakeAssumedRoleARN, // verifies assumed role is recorded in audit + skipVerifySignature: true, // not re-signing errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "Lambda access", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: lambdaRequest, - wantHost: "lambda.us-east-1.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "lambda", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionRequest{}, + name: "Lambda access", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: lambdaRequest, + wantHost: "lambda.us-east-1.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "lambda", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionRequest{}, errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "Request exceeding max size", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: maxSizeExceededRequest, - wantHost: "lambda.us-east-1.amazonaws.com", + name: "Request exceeding max size", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: maxSizeExceededRequest, + wantHost: "lambda.us-east-1.amazonaws.com", errAssertionFns: []require.ErrorAssertionFunc{ // TODO(gavin): change this to [http.StatusRequestEntityTooLarge] // after updating [trace.ErrorToCode]. @@ -426,52 +410,46 @@ func TestAWSSignerHandler(t *testing.T) { }, }, { - name: "AssumeRole success (shorter identity duration)", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: assumeRoleRequest(2 * time.Hour), - advanceClock: 10 * time.Minute, - wantHost: "sts.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "sts", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionRequest{}, - verifySentRequest: verifyAssumeRoleDuration(50 * time.Minute), // 1h (suite default for identity) - 10m + name: "AssumeRole success (shorter identity duration)", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: assumeRoleRequest(2 * time.Hour), + advanceClock: 10 * time.Minute, + wantHost: "sts.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "sts", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionRequest{}, + verifySentRequest: verifyAssumeRoleDuration(50 * time.Minute), // 1h (suite default for identity) - 10m errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "AssumeRole success (shorter requested duration)", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: assumeRoleRequest(32 * time.Minute), - wantHost: "sts.amazonaws.com", - wantAuthCredKeyID: "FAKEACCESSKEYID", - wantAuthCredService: "sts", - wantAuthCredRegion: "us-east-1", - wantEventType: &events.AppSessionRequest{}, - verifySentRequest: verifyAssumeRoleDuration(32 * time.Minute), // matches the request + name: "AssumeRole success (shorter requested duration)", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: assumeRoleRequest(32 * time.Minute), + wantHost: "sts.amazonaws.com", + wantAuthCredKeyID: "FAKEACCESSKEYID", + wantAuthCredService: "sts", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionRequest{}, + verifySentRequest: verifyAssumeRoleDuration(32 * time.Minute), // matches the request errAssertionFns: []require.ErrorAssertionFunc{ require.NoError, }, }, { - name: "AssumeRole denied", - app: consoleApp, - awsClientSession: session.Must(session.NewSession(&aws.Config{ - Credentials: staticAWSCredentialsForClient, - Region: aws.String("us-east-1"), - })), - request: assumeRoleRequest(2 * time.Hour), - wantHost: "sts.amazonaws.com", - advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum + name: "AssumeRole denied", + app: consoleApp, + awsCredentialsProvider: staticAWSCredentialsForClient, + awsRegion: "us-east-1", + request: assumeRoleRequest(2 * time.Hour), + wantHost: "sts.amazonaws.com", + advanceClock: 50 * time.Minute, // identity is expiring in 10m which is less than minimum errAssertionFns: []require.ErrorAssertionFunc{ // the request is 403 forbidden by Teleport, so the mock AWS handler won't be sent anything. hasStatusCode(http.StatusForbidden), @@ -498,7 +476,7 @@ func TestAWSSignerHandler(t *testing.T) { // check that the signature is valid. if !tc.skipVerifySignature { err := awsutils.VerifyAWSSignature(r, - credentialsv2.NewStaticCredentialsProvider(tc.wantAuthCredKeyID, "secret", "token"), + credentials.NewStaticCredentialsProvider(tc.wantAuthCredKeyID, "secret", "token"), ) if !assert.NoError(t, err) { http.Error(w, err.Error(), trace.ErrorToCode(err)) @@ -521,7 +499,7 @@ func TestAWSSignerHandler(t *testing.T) { suite := createSuite(t, mockAwsHandler, tc.app, fakeClock, awsCfgProvider) fakeClock.Advance(tc.advanceClock) - err := tc.request(suite.URL, tc.awsClientSession, tc.wantHost) + err := tc.request(t.Context(), suite.URL, tc.awsRegion, tc.awsCredentialsProvider, tc.wantHost) for _, assertFn := range tc.errAssertionFns { assertFn(t, err) } @@ -625,8 +603,8 @@ func mustNewRequest(t *testing.T, method, url string, body io.Reader) *http.Requ const assumedRoleKeyID = "assumedRoleKeyID" var ( - staticAWSCredentialsForAssumedRole = credentials.NewStaticCredentials(assumedRoleKeyID, "assumedRoleKeySecret", "") - staticAWSCredentialsForClient = credentials.NewStaticCredentials("fakeClientKeyID", "fakeClientSecret", "") + staticAWSCredentialsForAssumedRole = credentials.NewStaticCredentialsProvider(assumedRoleKeyID, "assumedRoleKeySecret", "") + staticAWSCredentialsForClient = credentials.NewStaticCredentialsProvider("fakeClientKeyID", "fakeClientSecret", "") ) type suite struct { diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 67510a8c42c17..5f666738f5b88 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -69,11 +69,6 @@ const ( // azureVirtualMachineCacheTTL is the default TTL for Azure virtual machine // cache entries. azureVirtualMachineCacheTTL = 5 * time.Minute - - // emptyPayloadHash is the SHA-256 for an empty element (as in echo -n | sha256sum). - // PresignHTTP requires the hash of the body, but when there is no body we hash the empty string. - // https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html - emptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" ) // Auth defines interface for creating auth tokens and TLS configurations. @@ -1323,7 +1318,7 @@ func (r *awsRedisIAMTokenRequest) toSignedRequestURI(ctx context.Context) (strin if err != nil { return "", trace.Wrap(err) } - signedURI, _, err := signer.PresignHTTP(ctx, creds, req, emptyPayloadHash, r.serviceName, r.region, r.clock.Now()) + signedURI, _, err := signer.PresignHTTP(ctx, creds, req, awsutils.EmptyPayloadHash, r.serviceName, r.region, r.clock.Now()) if err != nil { return "", trace.Wrap(err) } diff --git a/lib/srv/server/ec2_watcher_test.go b/lib/srv/server/ec2_watcher_test.go index f62cbb737d5f4..cb479a20631bb 100644 --- a/lib/srv/server/ec2_watcher_test.go +++ b/lib/srv/server/ec2_watcher_test.go @@ -22,10 +22,9 @@ import ( "context" "testing" - awsv2 "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2" ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - "github.com/aws/aws-sdk-go/aws" "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" @@ -58,16 +57,16 @@ func (m *mockEC2Client) DescribeInstances(ctx context.Context, input *ec2.Descri func instanceMatches(inst ec2types.Instance, filters []ec2types.Filter) bool { allMatched := true for _, filter := range filters { - name := awsv2.ToString(filter.Name) + name := aws.ToString(filter.Name) val := filter.Values[0] if name == AWSInstanceStateName && inst.State.Name != ec2types.InstanceStateNameRunning { return false } for _, tag := range inst.Tags { - if awsv2.ToString(tag.Key) != name[4:] { + if aws.ToString(tag.Key) != name[4:] { continue } - allMatched = allMatched && awsv2.ToString(tag.Value) != val + allMatched = allMatched && aws.ToString(tag.Value) != val } } @@ -91,7 +90,7 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) { }, expectedFilters: []ec2types.Filter{ { - Name: awsv2.String(AWSInstanceStateName), + Name: aws.String(AWSInstanceStateName), Values: []string{string(ec2types.InstanceStateNameRunning)}, }, }, @@ -105,11 +104,11 @@ func TestNewEC2InstanceFetcherTags(t *testing.T) { }, expectedFilters: []ec2types.Filter{ { - Name: awsv2.String(AWSInstanceStateName), + Name: aws.String(AWSInstanceStateName), Values: []string{string(ec2types.InstanceStateNameRunning)}, }, { - Name: awsv2.String("tag:hello"), + Name: aws.String("tag:hello"), Values: []string{"other"}, }, }, @@ -156,15 +155,15 @@ func TestEC2Watcher(t *testing.T) { ctx := context.Background() present := ec2types.Instance{ - InstanceId: awsv2.String("instance-present"), + InstanceId: aws.String("instance-present"), Tags: []ec2types.Tag{ { - Key: awsv2.String("teleport"), - Value: awsv2.String("yes"), + Key: aws.String("teleport"), + Value: aws.String("yes"), }, { - Key: awsv2.String("Name"), - Value: awsv2.String("Present"), + Key: aws.String("Name"), + Value: aws.String("Present"), }, }, State: &ec2types.InstanceState{ @@ -172,20 +171,20 @@ func TestEC2Watcher(t *testing.T) { }, } presentOther := ec2types.Instance{ - InstanceId: awsv2.String("instance-present-2"), + InstanceId: aws.String("instance-present-2"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("dev"), + Key: aws.String("env"), + Value: aws.String("dev"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, }, } presentForEICE := ec2types.Instance{ - InstanceId: awsv2.String("instance-present-3"), + InstanceId: aws.String("instance-present-3"), Tags: []ec2types.Tag{{ - Key: awsv2.String("with-eice"), - Value: awsv2.String("please"), + Key: aws.String("with-eice"), + Value: aws.String("please"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, @@ -199,23 +198,23 @@ func TestEC2Watcher(t *testing.T) { presentOther, presentForEICE, { - InstanceId: awsv2.String("instance-absent"), + InstanceId: aws.String("instance-absent"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("prod"), + Key: aws.String("env"), + Value: aws.String("prod"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNameRunning, }, }, { - InstanceId: awsv2.String("instance-absent-3"), + InstanceId: aws.String("instance-absent-3"), Tags: []ec2types.Tag{{ - Key: awsv2.String("env"), - Value: awsv2.String("prod"), + Key: aws.String("env"), + Value: aws.String("prod"), }, { - Key: awsv2.String("teleport"), - Value: awsv2.String("yes"), + Key: aws.String("teleport"), + Value: aws.String("yes"), }}, State: &ec2types.InstanceState{ Name: ec2types.InstanceStateNamePending, diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index 6b97e237ca861..036542b181ecd 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -19,9 +19,9 @@ package aws import ( - "bytes" "context" "crypto/sha1" + "crypto/sha256" "encoding/hex" "fmt" "log/slog" @@ -33,14 +33,12 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" "github.com/gravitational/trace" apievents "github.com/gravitational/teleport/api/types/events" apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/utils/aws/migration" ) const ( @@ -77,6 +75,11 @@ const ( // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_iam-quotas.html MaxRoleSessionNameLength = 64 + // EmptyPayloadHash is the SHA-256 for an empty element (as in echo -n | sha256sum). + // PresignHTTP requires the hash of the body, but when there is no body we hash the empty string. + // https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html + EmptyPayloadHash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + iamServiceName = "iam" ) @@ -195,7 +198,8 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) return trace.Wrap(err) } - reqCopy := req.Clone(context.Background()) + ctx := context.Background() + reqCopy := req.Clone(ctx) // Remove all the headers that are not present in awsCred.SignedHeaders. filterHeaders(reqCopy, sigV4.SignedHeaders) @@ -207,8 +211,13 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) return trace.BadParameter("%s", err) } - signer := NewSignerV2(credProvider, sigV4.Service) - _, err = signer.Sign(reqCopy, bytes.NewReader(payload), sigV4.Service, sigV4.Region, t) + creds, err := credProvider.Retrieve(ctx) + if err != nil { + return trace.Wrap(err) + } + + signer := NewSignerV2(sigV4.Service) + err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), sigV4.Service, sigV4.Region, t) if err != nil { return trace.Wrap(err) } @@ -227,22 +236,23 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) } // NewSignerV2 is a temporary AWS SDK migration helper. -func NewSignerV2(provider aws.CredentialsProvider, signingServiceName string) *v4.Signer { - return NewSigner(migration.NewCredentialsAdapter(provider), signingServiceName) -} - -// NewSigner creates a new V4 signer. -func NewSigner(credentials *credentials.Credentials, signingServiceName string) *v4.Signer { - options := func(s *v4.Signer) { +func NewSignerV2(signingServiceName string) *v4.Signer { + return v4.NewSigner(func(opts *v4.SignerOptions) { // s3 and s3control requests are signed with URL unescaped (found by // searching "DisableURIPathEscaping" in "aws-sdk-go/service"). Both // services use "s3" as signing name. See description of // "DisableURIPathEscaping" for more details. if signingServiceName == "s3" { - s.DisableURIPathEscaping = true + opts.DisableURIPathEscaping = true } - } - return v4.NewSigner(credentials, options) + }) +} + +// GetV4PayloadHash returns the V4 signing payload hash. +func GetV4PayloadHash(payload []byte) string { + hash := sha256.New() + hash.Write(payload) + return hex.EncodeToString(hash.Sum(nil)) } // filterHeaders removes request headers that are not in the headers list and returns the removed header keys. diff --git a/lib/utils/aws/credentials.go b/lib/utils/aws/credentials.go deleted file mode 100644 index 8be99d898e5f7..0000000000000 --- a/lib/utils/aws/credentials.go +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package aws - -import ( - "context" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/modules" -) - -// AWSSessionProvider defines a function that creates an AWS Session. -// It must use ambient credentials if Integration is empty. -// It must use Integration credentials otherwise. -type AWSSessionProvider func(ctx context.Context, region string, integration string) (*session.Session, error) - -// StaticAWSSessionProvider is a helper method that returns a static session. -// Must not be used to provide sessions when using Integrations. -func StaticAWSSessionProvider(awsSession *session.Session) AWSSessionProvider { - return func(ctx context.Context, region, integration string) (*session.Session, error) { - if integration != "" { - return nil, trace.BadParameter("integration %q is not allowed to use static sessions", integration) - } - return awsSession, nil - } -} - -// SessionProviderUsingAmbientCredentials returns an AWS Session using ambient credentials. -// This is in contrast with AWS Sessions that can be generated using an AWS OIDC Integration. -func SessionProviderUsingAmbientCredentials() AWSSessionProvider { - return func(ctx context.Context, region, integration string) (*session.Session, error) { - if integration != "" { - return nil, trace.BadParameter("integration %q is not allowed to use ambient sessions", integration) - } - useFIPSEndpoint := endpoints.FIPSEndpointStateUnset - if modules.GetModules().IsBoringBinary() { - useFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - session, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigEnable, - Config: aws.Config{ - UseFIPSEndpoint: useFIPSEndpoint, - }, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return session, nil - } -} diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index 31d29532c20d8..d98311ec90c56 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -19,10 +19,10 @@ package aws import ( - "bytes" "context" "io" "net/http" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/gravitational/trace" @@ -94,8 +94,14 @@ func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (* // 100-continue" headers without being signed, otherwise the Athena service // would reject the requests. unsignedHeaders := removeUnsignedHeaders(reqCopy) - signer := NewSignerV2(signCtx.Credentials, signCtx.SigningName) - _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, signCtx.Clock.Now()) + + creds, err := signCtx.Credentials.Retrieve(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + signer := NewSignerV2(signCtx.SigningName) + err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), signCtx.SigningName, signCtx.SigningRegion, time.Now()) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/utils/aws/xml.go b/lib/utils/aws/xml.go index e88ca680a5934..69e9c5b84c780 100644 --- a/lib/utils/aws/xml.go +++ b/lib/utils/aws/xml.go @@ -22,62 +22,52 @@ import ( "bytes" "encoding/xml" - "github.com/aws/aws-sdk-go/private/protocol/xml/xmlutil" + smithyxml "github.com/aws/smithy-go/encoding/xml" + "github.com/gravitational/trace" ) // IsXMLOfLocalName returns true if the root XML has the provided (local) name. func IsXMLOfLocalName(data []byte, wantLocalName string) bool { - var name xml.Name - if err := xml.Unmarshal(data, &name); err == nil { - return wantLocalName == name.Local + st, err := smithyxml.FetchRootElement(xml.NewDecoder(bytes.NewReader(data))) + if err == nil && st.Name.Local == wantLocalName { + return true } + return false } // UnmarshalXMLChildNode decodes the XML-encoded data and stores the child node // with the specified name to v, where v is a pointer to an AWS SDK v1 struct. func UnmarshalXMLChildNode(v interface{}, data []byte, childName string) error { - return trace.Wrap(xmlutil.UnmarshalXML(v, xml.NewDecoder(bytes.NewReader(data)), childName)) + decoder := xml.NewDecoder(bytes.NewReader(data)) + st, err := smithyxml.FetchRootElement(decoder) + if err != nil { + return trace.Wrap(err) + } + nodeDecoder := smithyxml.WrapNodeDecoder(decoder, st) + childElem, err := nodeDecoder.GetElement(childName) + if err != nil { + return trace.Wrap(err) + } + + return trace.Wrap(decoder.DecodeElement(v, &childElem)) } // MarshalXML marshals the provided root name and a map of children in XML with // default indent (prefix "", indent " "). -func MarshalXML(rootName xml.Name, children map[string]any) ([]byte, error) { +func MarshalXML(root string, namespace string, v any) ([]byte, error) { var buf bytes.Buffer encoder := xml.NewEncoder(&buf) encoder.Indent("", " ") - - err := encodeXMLNode(encoder, rootName, func() error { - for childName, childValue := range children { - if err := encodeXMLNodeAWSSDKV1(encoder, childName, childValue); err != nil { - return trace.Wrap(err) - } - } - return nil + err := encoder.EncodeElement(v, xml.StartElement{ + Name: xml.Name{Local: root}, + Attr: []xml.Attr{ + {Name: xml.Name{Local: "xmlns"}, Value: namespace}, + }, }) if err != nil { return nil, trace.Wrap(err) } - if err := trace.Wrap(encoder.Flush()); err != nil { - return nil, trace.Wrap(err) - } return buf.Bytes(), nil } - -func encodeXMLNode(encoder *xml.Encoder, name xml.Name, encodeChildren func() error) error { - startElement := xml.StartElement{Name: name} - if err := encoder.EncodeToken(startElement); err != nil { - return trace.Wrap(err) - } - if err := encodeChildren(); err != nil { - return trace.Wrap(err) - } - return trace.Wrap(encoder.EncodeToken(startElement.End())) -} - -func encodeXMLNodeAWSSDKV1(encoder *xml.Encoder, name string, v any) error { - return encodeXMLNode(encoder, xml.Name{Local: name}, func() error { - return trace.Wrap(xmlutil.BuildXML(v, encoder)) - }) -} diff --git a/lib/utils/aws/xml_test.go b/lib/utils/aws/xml_test.go index f2553e6a4cb9c..b4bb39d6f14ad 100644 --- a/lib/utils/aws/xml_test.go +++ b/lib/utils/aws/xml_test.go @@ -19,15 +19,13 @@ package aws import ( - "encoding/xml" - "net/http" "strings" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/private/protocol" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + ststypes "github.com/aws/aws-sdk-go-v2/service/sts/types" "github.com/stretchr/testify/require" ) @@ -40,10 +38,10 @@ func TestIsXMLOfLocalName(t *testing.T) { func TestUnmarshalXMLChildNode(t *testing.T) { want := sts.AssumeRoleOutput{ - AssumedRoleUser: &sts.AssumedRoleUser{ + AssumedRoleUser: &ststypes.AssumedRoleUser{ Arn: aws.String("some-arn"), }, - Credentials: &sts.Credentials{ + Credentials: &ststypes.Credentials{ AccessKeyId: aws.String("some-access-key-id"), SecretAccessKey: aws.String("some-secret-access-key"), SessionToken: aws.String("some-session-token"), @@ -75,49 +73,22 @@ func TestUnmarshalXMLChildNode(t *testing.T) { } func TestMarshalXMLIndent(t *testing.T) { - data, err := MarshalXML( - xml.Name{ - Local: "AssumeRoleResponse", - Space: "https://sts.amazonaws.com/doc/2011-06-15/", - }, - map[string]any{ - "AssumeRoleResult": sts.AssumeRoleOutput{ - AssumedRoleUser: &sts.AssumedRoleUser{ - Arn: aws.String("some-arn"), - }, - Credentials: &sts.Credentials{ - AccessKeyId: aws.String("some-access-key-id"), - SecretAccessKey: aws.String("some-secret-access-key"), - SessionToken: aws.String("some-session-token"), - Expiration: aws.Time(time.Unix(1234567890, 0).UTC()), - }, - }, - "ResponseMetadata": protocol.ResponseMetadata{ - RequestID: "some-request-id", - StatusCode: http.StatusOK, - }, - }, - ) + simpleAssumeRole := struct { + Test string + Encoding time.Time + }{ + Test: "test", + Encoding: time.Unix(1234567890, 0).UTC(), + } + + data, err := MarshalXML("AssumeRoleResponse", "https://sts.amazonaws.com/doc/2011-06-15/", simpleAssumeRole) require.NoError(t, err) // Nodes are not sorted. Use ElementsMatch to ensure each line is present. require.ElementsMatch(t, []string{ ``, - ` `, - ` `, - ` some-secret-access-key`, - ` some-session-token`, - ` some-access-key-id`, - ` 2009-02-13T23:31:30Z`, - ` `, - ` `, - ` some-arn`, - ` `, - ` `, - ` `, - ` 200`, - ` some-request-id`, - ` `, + ` test`, + ` 2009-02-13T23:31:30Z`, ``, }, strings.Split(string(data), "\n")) } From 9ba673e14bc1ff77c8e058fa6c44dffd0ba56dad Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Wed, 12 Mar 2025 15:30:35 -0300 Subject: [PATCH 2/4] refactor(aws): remove legacy error conversion for s3 --- lib/utils/aws/s3.go | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/lib/utils/aws/s3.go b/lib/utils/aws/s3.go index cd050dff34ce9..c6fcb0275c736 100644 --- a/lib/utils/aws/s3.go +++ b/lib/utils/aws/s3.go @@ -22,15 +22,12 @@ import ( "context" "errors" "io" - "net/http" "strings" awsv2 "github.com/aws/aws-sdk-go-v2/aws" managerv2 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3v2 "github.com/aws/aws-sdk-go-v2/service/s3" s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/smithy-go" "github.com/gravitational/trace" ) @@ -42,25 +39,6 @@ func ConvertS3Error(err error) error { return nil } - // SDK v1 errors: - var rerr awserr.RequestFailure - if errors.As(err, &rerr) && rerr.StatusCode() == http.StatusForbidden { - return trace.AccessDenied("%s", rerr.Message()) - } - - var aerr awserr.Error - if errors.As(err, &aerr) { - switch aerr.Code() { - case s3.ErrCodeNoSuchKey, s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchUpload, "NotFound": - return trace.NotFound("%s", aerr) - case s3.ErrCodeBucketAlreadyExists, s3.ErrCodeBucketAlreadyOwnedByYou: - return trace.AlreadyExists("%s", aerr) - default: - return trace.BadParameter("%s", aerr) - } - } - - // SDK v2 errors: var noSuchKey *s3types.NoSuchKey if errors.As(err, &noSuchKey) { return trace.NotFound("%s", noSuchKey) From 621f3f462f1c25dda84b70de2850d95b46964cee Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Wed, 12 Mar 2025 15:44:23 -0300 Subject: [PATCH 3/4] refactor(awsoidc): remove unused clientv1 --- lib/integrations/awsoidc/clientsv1.go | 124 -------------------- lib/integrations/awsoidc/clientsv1_test.go | 129 --------------------- 2 files changed, 253 deletions(-) delete mode 100644 lib/integrations/awsoidc/clientsv1.go delete mode 100644 lib/integrations/awsoidc/clientsv1_test.go diff --git a/lib/integrations/awsoidc/clientsv1.go b/lib/integrations/awsoidc/clientsv1.go deleted file mode 100644 index 576badf79951d..0000000000000 --- a/lib/integrations/awsoidc/clientsv1.go +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package awsoidc - -import ( - "context" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/api/types" - utilsaws "github.com/gravitational/teleport/api/utils/aws" - "github.com/gravitational/teleport/lib/modules" - "github.com/gravitational/teleport/lib/utils/aws/stsutils" -) - -// FetchToken returns the token. -func (j IdentityToken) FetchToken(ctx credentials.Context) ([]byte, error) { - return []byte(j), nil -} - -// IntegrationTokenGenerator is an interface that indicates which APIs are required to generate an Integration Token. -type IntegrationTokenGenerator interface { - // GetIntegration returns the specified integration resources. - GetIntegration(ctx context.Context, name string) (types.Integration, error) - - // GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action. - GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) -} - -// NewSessionV1 creates a new AWS Session for the region using the integration as source of credentials. -// This session is usable for AWS SDK Go V1. -func NewSessionV1(ctx context.Context, client IntegrationTokenGenerator, region string, integrationName string) (*session.Session, error) { - if region != "" { - if err := utilsaws.IsValidRegion(region); err != nil { - return nil, trace.Wrap(err) - } - } - integration, err := client.GetIntegration(ctx, integrationName) - if err != nil { - return nil, trace.Wrap(err) - } - - awsOIDCIntegration := integration.GetAWSOIDCIntegrationSpec() - if awsOIDCIntegration == nil { - return nil, trace.BadParameter("invalid integration subkind, expected awsoidc, got %s", integration.GetSubKind()) - } - - useFIPSEndpoint := endpoints.FIPSEndpointStateUnset - if modules.GetModules().IsBoringBinary() { - useFIPSEndpoint = endpoints.FIPSEndpointStateEnabled - } - - sess, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigDisable, - Config: aws.Config{ - UseFIPSEndpoint: useFIPSEndpoint, - }, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - // AWS SDK calls FetchToken everytime the session is no longer valid (or there's no session). - // Generating a token here and using it as a Static would make this token valid for the Max Duration Session for the current AWS Role (usually, 1 hour). - // Instead, it generates a token everytime the Session's client requests a new token, ensuring it always receives a fresh one. - var integrationTokenFetcher IntegrationTokenFetcher = func(ctx context.Context) ([]byte, error) { - token, err := client.GenerateAWSOIDCToken(ctx, integrationName) - return []byte(token), trace.Wrap(err) - } - - stsSTS := stsutils.NewV1(sess) - roleProvider := stscreds.NewWebIdentityRoleProviderWithOptions( - stsSTS, - awsOIDCIntegration.RoleARN, - "", - integrationTokenFetcher, - ) - awsCredentials := credentials.NewCredentials(roleProvider) - - session, err := session.NewSessionWithOptions(session.Options{ - SharedConfigState: session.SharedConfigDisable, - Config: aws.Config{ - Region: aws.String(region), - Credentials: awsCredentials, - UseFIPSEndpoint: useFIPSEndpoint, - }, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - return session, nil -} - -// IntegrationTokenFetcher handles dynamic token generation using a callback function. -// Useful to embed as a [stscreds.TokenFetcher]. -type IntegrationTokenFetcher func(context.Context) ([]byte, error) - -// FetchToken returns a token by calling the callback function. -func (genFn IntegrationTokenFetcher) FetchToken(ctx context.Context) ([]byte, error) { - token, err := genFn(ctx) - return token, trace.Wrap(err) -} diff --git a/lib/integrations/awsoidc/clientsv1_test.go b/lib/integrations/awsoidc/clientsv1_test.go deleted file mode 100644 index 66a350ebe67a6..0000000000000 --- a/lib/integrations/awsoidc/clientsv1_test.go +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package awsoidc - -import ( - "context" - "testing" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/google/uuid" - "github.com/gravitational/trace" - "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/api/types" -) - -type mockIntegrationsTokenGenerator struct { - proxies []types.Server - integrations map[string]types.Integration - tokenCallsCount int -} - -// GetIntegration returns the specified integration resources. -func (m *mockIntegrationsTokenGenerator) GetIntegration(ctx context.Context, name string) (types.Integration, error) { - if ig, found := m.integrations[name]; found { - return ig, nil - } - - return nil, trace.NotFound("integration not found") -} - -// GetProxies returns a list of registered proxies. -func (m *mockIntegrationsTokenGenerator) GetProxies() ([]types.Server, error) { - return m.proxies, nil -} - -// GenerateAWSOIDCToken generates a token to be used to execute an AWS OIDC Integration action. -func (m *mockIntegrationsTokenGenerator) GenerateAWSOIDCToken(ctx context.Context, integration string) (string, error) { - m.tokenCallsCount++ - return uuid.NewString(), nil -} - -func TestNewSessionV1(t *testing.T) { - ctx := context.Background() - - dummyIntegration, err := types.NewIntegrationAWSOIDC( - types.Metadata{Name: "myawsintegration"}, - &types.AWSOIDCIntegrationSpecV1{ - RoleARN: "arn:aws:sts::123456789012:role/TestRole", - }, - ) - require.NoError(t, err) - - dummyProxy, err := types.NewServer( - "proxy-123", types.KindProxy, - types.ServerSpecV2{ - PublicAddrs: []string{"https://localhost:3080/"}, - }, - ) - require.NoError(t, err) - - for _, tt := range []struct { - name string - region string - integration string - tokenFetchCount int - expectedErr require.ErrorAssertionFunc - sessionValidator func(*testing.T, *session.Session) - }{ - { - name: "valid", - region: "us-dummy-1", - integration: "myawsintegration", - expectedErr: require.NoError, - sessionValidator: func(t *testing.T, s *session.Session) { - require.Equal(t, aws.String("us-dummy-1"), s.Config.Region) - }, - }, - { - name: "valid with empty region", - region: "", - integration: "myawsintegration", - expectedErr: require.NoError, - sessionValidator: func(t *testing.T, s *session.Session) { - require.Equal(t, "", aws.StringValue(s.Config.Region)) - }, - }, - { - name: "not found error when integration is missing", - region: "us-dummy-1", - integration: "not-found", - expectedErr: notFoundCheck, - }, - } { - t.Run(tt.name, func(t *testing.T) { - mockTokenGenertor := &mockIntegrationsTokenGenerator{ - proxies: []types.Server{dummyProxy}, - integrations: map[string]types.Integration{ - dummyIntegration.GetName(): dummyIntegration, - }, - } - awsSessionOut, err := NewSessionV1(ctx, mockTokenGenertor, tt.region, tt.integration) - - tt.expectedErr(t, err) - if tt.sessionValidator != nil { - tt.sessionValidator(t, awsSessionOut) - } - require.Zero(t, tt.tokenFetchCount) - }) - } - -} From 6f80f34dcd96b520e1a753f89e399fa191985340 Mon Sep 17 00:00:00 2001 From: Gabriel Corado Date: Wed, 12 Mar 2025 15:50:14 -0300 Subject: [PATCH 4/4] refactor: code review suggestions --- lib/srv/alpnproxy/aws_local_proxy_test.go | 4 ++-- lib/utils/aws/aws.go | 10 +++++++--- lib/utils/aws/signing.go | 2 +- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/lib/srv/alpnproxy/aws_local_proxy_test.go b/lib/srv/alpnproxy/aws_local_proxy_test.go index 75591e1062f98..083001fd30bae 100644 --- a/lib/srv/alpnproxy/aws_local_proxy_test.go +++ b/lib/srv/alpnproxy/aws_local_proxy_test.go @@ -45,10 +45,10 @@ func TestAWSAccessMiddleware(t *testing.T) { require.NoError(t, m.CheckAndSetDefaults()) stsRequestByLocalProxyCred := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", nil) - awsutils.NewSignerV2("sts").SignHTTP(t.Context(), localCred, stsRequestByLocalProxyCred, awsutils.EmptyPayloadHash, "sts", "us-west-1", time.Now()) + awsutils.NewSigner("sts").SignHTTP(t.Context(), localCred, stsRequestByLocalProxyCred, awsutils.EmptyPayloadHash, "sts", "us-west-1", time.Now()) requestByAssumedRole := httptest.NewRequest(http.MethodGet, "http://s3.amazonaws.com", nil) - awsutils.NewSignerV2("s3").SignHTTP(t.Context(), assumedRoleCred, requestByAssumedRole, awsutils.EmptyPayloadHash, "s3", "us-west-1", time.Now()) + awsutils.NewSigner("s3").SignHTTP(t.Context(), assumedRoleCred, requestByAssumedRole, awsutils.EmptyPayloadHash, "s3", "us-west-1", time.Now()) t.Run("request no authorization", func(t *testing.T) { recorder := httptest.NewRecorder() diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index 036542b181ecd..106ccd8e7f24c 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -216,7 +216,7 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) return trace.Wrap(err) } - signer := NewSignerV2(sigV4.Service) + signer := NewSigner(sigV4.Service) err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), sigV4.Service, sigV4.Region, t) if err != nil { return trace.Wrap(err) @@ -235,8 +235,8 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider) return nil } -// NewSignerV2 is a temporary AWS SDK migration helper. -func NewSignerV2(signingServiceName string) *v4.Signer { +// NewSigner creates a new V4 signer. +func NewSigner(signingServiceName string) *v4.Signer { return v4.NewSigner(func(opts *v4.SignerOptions) { // s3 and s3control requests are signed with URL unescaped (found by // searching "DisableURIPathEscaping" in "aws-sdk-go/service"). Both @@ -250,6 +250,10 @@ func NewSignerV2(signingServiceName string) *v4.Signer { // GetV4PayloadHash returns the V4 signing payload hash. func GetV4PayloadHash(payload []byte) string { + if len(payload) == 0 { + return EmptyPayloadHash + } + hash := sha256.New() hash.Write(payload) return hex.EncodeToString(hash.Sum(nil)) diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index d98311ec90c56..385ccfdcbae79 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -100,7 +100,7 @@ func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (* return nil, trace.Wrap(err) } - signer := NewSignerV2(signCtx.SigningName) + signer := NewSigner(signCtx.SigningName) err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), signCtx.SigningName, signCtx.SigningRegion, time.Now()) if err != nil { return nil, trace.Wrap(err)