Skip to content

Commit

Permalink
Merge pull request #3 from So-Sahari/feature/limit-bedrock-list
Browse files Browse the repository at this point in the history
feat: limit list models to supported model structs
  • Loading branch information
catpaladin authored Jun 22, 2024
2 parents cd90d7e + 50d184d commit 668884f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
3 changes: 3 additions & 0 deletions internal/bedrock/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ import (
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
)

// SupportedModels contains all supported models implemented
var SupportedModels = []string{"sonnet", "anthropic", "cohere", "mistral"}

// AWSModelConfig contains all settings for AWS Models
type AWSModelConfig struct {
ModelID string // ModelID is the ID of the model to invoke
Expand Down
23 changes: 18 additions & 5 deletions internal/bedrock/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import (
"fmt"
"sort"
"strconv"
"strings"

"jenn-ai/internal/fuzzy"

"github.com/aws/aws-sdk-go-v2/service/bedrock"
"github.com/aws/aws-sdk-go-v2/service/bedrock/types"
)

// FoundationModel contains model fields
Expand Down Expand Up @@ -38,17 +40,28 @@ func SelectBedrockModel(ctx context.Context, client *bedrock.Client) (string, er
func ListModels(ctx context.Context, api ClientAPI) ([]FoundationModel, error) {
var output []FoundationModel

response, err := api.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{})
response, err := api.ListFoundationModels(ctx, &bedrock.ListFoundationModelsInput{
ByOutputModality: types.ModelModalityText,
})
if err != nil {
return output, err
}

filteredModels := []types.FoundationModelSummary{}
for _, model := range response.ModelSummaries {
for _, m := range SupportedModels {
if strings.Contains(*model.ModelId, m) {
filteredModels = append(filteredModels, model)
}
}
}

for _, sm := range filteredModels {
output = append(output, FoundationModel{
Name: *model.ModelName,
Provider: *model.ProviderName,
ID: *model.ModelId,
Modality: fmt.Sprintf("%v", model.OutputModalities),
Name: *sm.ModelName,
Provider: *sm.ProviderName,
ID: *sm.ModelId,
Modality: fmt.Sprintf("%v", sm.OutputModalities),
})
}
return output, nil
Expand Down

0 comments on commit 668884f

Please sign in to comment.