Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

extproc: retrieve AWS credentials for every request #185

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions internal/extproc/backendauth/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,39 +22,27 @@ import (

// awsHandler implements [Handler] for AWS Bedrock authz.
type awsHandler struct {
credentials aws.Credentials
signer *v4.Signer
region string
credentials aws.Credentials
credentialFileName string
signer *v4.Signer
region string
}

func newAWSHandler(awsAuth *filterconfig.AWSAuth) (*awsHandler, error) {
var credentials aws.Credentials
var region string
var credentialFileName string

// TODO: refactor to work with refreshing credentials (similar to API Key)
if awsAuth != nil {
region = awsAuth.Region
if len(awsAuth.CredentialFileName) != 0 {
cfg, err := config.LoadDefaultConfig(
context.Background(),
config.WithSharedCredentialsFiles([]string{awsAuth.CredentialFileName}),
config.WithRegion(awsAuth.Region),
)
if err != nil {
return nil, fmt.Errorf("cannot load from credentials file: %w", err)
}
credentials, err = cfg.Credentials.Retrieve(context.Background())
if err != nil {
return nil, fmt.Errorf("cannot retrieve AWS credentials: %w", err)
}
}
credentialFileName = awsAuth.CredentialFileName
} else {
return nil, fmt.Errorf("aws auth configuration is required")
}

signer := v4.NewSigner()

return &awsHandler{credentials: credentials, signer: signer, region: region}, nil
return &awsHandler{credentials: credentials, credentialFileName: credentialFileName, signer: signer, region: region}, nil
}

// Do implements [Handler.Do].
Expand Down Expand Up @@ -83,6 +71,19 @@ func (a *awsHandler) Do(requestHeaders map[string]string, headerMut *extprocv3.H
body = _body
}

cfg, err := config.LoadDefaultConfig(
context.Background(),
config.WithSharedCredentialsFiles([]string{a.credentialFileName}),
config.WithRegion(a.region),
)
if err != nil {
return fmt.Errorf("cannot load from credentials file: %w", err)
}
credentials, err := cfg.Credentials.Retrieve(context.Background())
if err != nil {
return fmt.Errorf("cannot retrieve AWS credentials: %w", err)
}

payloadHash := sha256.Sum256(body)
req, err := http.NewRequest(method,
fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com%s", a.region, path),
Expand All @@ -91,7 +92,7 @@ func (a *awsHandler) Do(requestHeaders map[string]string, headerMut *extprocv3.H
return fmt.Errorf("cannot create request: %w", err)
}

err = a.signer.SignHTTP(context.Background(), a.credentials, req,
err = a.signer.SignHTTP(context.Background(), credentials, req,
hex.EncodeToString(payloadHash[:]), "bedrock", a.region, time.Now())
if err != nil {
return fmt.Errorf("cannot sign request: %w", err)
Expand Down
Loading