diff --git a/internal/extproc/backendauth/aws.go b/internal/extproc/backendauth/aws.go index 618964fe3..774677f00 100644 --- a/internal/extproc/backendauth/aws.go +++ b/internal/extproc/backendauth/aws.go @@ -22,39 +22,27 @@ import ( // awsHandler implements [Handler] for AWS Bedrock authz. type awsHandler struct { - credentials aws.Credentials - signer *v4.Signer - region string + credentials aws.Credentials + credentialFileName string + signer *v4.Signer + region string } func newAWSHandler(awsAuth *filterconfig.AWSAuth) (*awsHandler, error) { var credentials aws.Credentials var region string + var credentialFileName string // TODO: refactor to work with refreshing credentials (similar to API Key) if awsAuth != nil { region = awsAuth.Region - if len(awsAuth.CredentialFileName) != 0 { - cfg, err := config.LoadDefaultConfig( - context.Background(), - config.WithSharedCredentialsFiles([]string{awsAuth.CredentialFileName}), - config.WithRegion(awsAuth.Region), - ) - if err != nil { - return nil, fmt.Errorf("cannot load from credentials file: %w", err) - } - credentials, err = cfg.Credentials.Retrieve(context.Background()) - if err != nil { - return nil, fmt.Errorf("cannot retrieve AWS credentials: %w", err) - } - } + credentialFileName = awsAuth.CredentialFileName } else { return nil, fmt.Errorf("aws auth configuration is required") } signer := v4.NewSigner() - - return &awsHandler{credentials: credentials, signer: signer, region: region}, nil + return &awsHandler{credentials: credentials, credentialFileName: credentialFileName, signer: signer, region: region}, nil } // Do implements [Handler.Do]. @@ -83,6 +71,19 @@ func (a *awsHandler) Do(requestHeaders map[string]string, headerMut *extprocv3.H body = _body } + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithSharedCredentialsFiles([]string{a.credentialFileName}), + config.WithRegion(a.region), + ) + if err != nil { + return fmt.Errorf("cannot load from credentials file: %w", err) + } + credentials, err := cfg.Credentials.Retrieve(context.Background()) + if err != nil { + return fmt.Errorf("cannot retrieve AWS credentials: %w", err) + } + payloadHash := sha256.Sum256(body) req, err := http.NewRequest(method, fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com%s", a.region, path), @@ -91,7 +92,7 @@ func (a *awsHandler) Do(requestHeaders map[string]string, headerMut *extprocv3.H return fmt.Errorf("cannot create request: %w", err) } - err = a.signer.SignHTTP(context.Background(), a.credentials, req, + err = a.signer.SignHTTP(context.Background(), credentials, req, hex.EncodeToString(payloadHash[:]), "bedrock", a.region, time.Now()) if err != nil { return fmt.Errorf("cannot sign request: %w", err)