extproc: remove the path from the translator factory (#334)
**Commit Message**

extproc: remove the path from the translator factory

Removes the path from the translator factory, now that there is a
dedicated processor for the chat completion endpoint.

**Related Issues/PRs (if applicable)**

Follow-up for:
#325 (review)

**Special notes for reviewers (if applicable)**

Note that I don't remove the `Factory` type completely, so that only the
right translator is instantiated, and only when needed.
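
For reviewers skimming the diff, here is a minimal, self-contained sketch of the new selection flow. The types are simplified stand-ins, not the real `filterapi`/`translator` APIs; the actual implementation is the `selectTranslator` method shown in the diff below.

```go
package main

import "fmt"

// schema is a simplified stand-in for filterapi.VersionedAPISchema.
type schema struct{ Name, Version string }

// chatTranslator is a simplified stand-in for translator.Translator.
type chatTranslator interface{ ID() string }

type openAIToOpenAI struct{}

func (openAIToOpenAI) ID() string { return "openai-to-openai" }

type openAIToBedrock struct{}

func (openAIToBedrock) ID() string { return "openai-to-awsbedrock" }

// Before this commit, the processor looked up a per-schema factory and called
// factory(path). Now it selects the translator from the backend's output
// schema directly: lazily, at most once per request, and only the one needed.
type chatCompletionProcessor struct{ translator chatTranslator }

func (c *chatCompletionProcessor) selectTranslator(out schema) error {
	if c.translator != nil { // already selected, or injected by a test
		return nil
	}
	switch out.Name {
	case "OpenAI":
		c.translator = openAIToOpenAI{}
	case "AWSBedrock":
		c.translator = openAIToBedrock{}
	default:
		return fmt.Errorf("unsupported API schema: backend=%v", out)
	}
	return nil
}

func main() {
	c := &chatCompletionProcessor{}
	if err := c.selectTranslator(schema{Name: "AWSBedrock"}); err != nil {
		panic(err)
	}
	fmt.Println(c.translator.ID()) // openai-to-awsbedrock
}
```

This is also why the factory map and the path argument can go away: the processor builds exactly one translator per request, and tests can inject a translator directly instead of mocking a factory (see the updated tests below).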

---------

Signed-off-by: Ignasi Barrera <ignasi@tetrate.io>
nacx authored Feb 13, 2025
1 parent 3aeb52d commit f49dc98
Showing 12 changed files with 111 additions and 206 deletions.
120 changes: 65 additions & 55 deletions internal/extproc/chatcompletion_processor.go
@@ -38,42 +38,52 @@ type chatCompletionProcessor struct {
costs translator.LLMTokenUsage
}

// selectTranslator selects the translator based on the output schema.
func (c *chatCompletionProcessor) selectTranslator(out filterapi.VersionedAPISchema) error {
if c.translator != nil { // Prevents re-selection and allows translator injection in tests.
return nil
}
// TODO: currently, we ignore the LLMAPISchema."Version" field.
switch out.Name {
case filterapi.APISchemaOpenAI:
c.translator = translator.NewChatCompletionOpenAIToOpenAITranslator()
case filterapi.APISchemaAWSBedrock:
c.translator = translator.NewChatCompletionOpenAIToAWSBedrockTranslator()
default:
return fmt.Errorf("unsupported API schema: backend=%s", out)
}
return nil
}

// ProcessRequestHeaders implements [ProcessorIface.ProcessRequestHeaders].
func (p *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
func (c *chatCompletionProcessor) ProcessRequestHeaders(_ context.Context, _ *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
// The request headers have already been at the time the processor was created
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_RequestHeaders{
RequestHeaders: &extprocv3.HeadersResponse{},
}}, nil
}

// ProcessRequestBody implements [ProcessorIface.ProcessRequestBody].
func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := p.requestHeaders[":path"]
model, body, err := p.config.bodyParser(path, rawBody)
func (c *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBody *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
path := c.requestHeaders[":path"]
model, body, err := c.config.bodyParser(path, rawBody)
if err != nil {
return nil, fmt.Errorf("failed to parse request body: %w", err)
}
p.logger.Info("Processing request", "path", path, "model", model)
c.logger.Info("Processing request", "path", path, "model", model)

p.requestHeaders[p.config.modelNameHeaderKey] = model
b, err := p.config.router.Calculate(p.requestHeaders)
c.requestHeaders[c.config.modelNameHeaderKey] = model
b, err := c.config.router.Calculate(c.requestHeaders)
if err != nil {
return nil, fmt.Errorf("failed to calculate route: %w", err)
}
p.logger.Info("Selected backend", "backend", b.Name)
c.logger.Info("Selected backend", "backend", b.Name)

factory, ok := p.config.factories[b.Schema]
if !ok {
return nil, fmt.Errorf("failed to find factory for output schema %q", b.Schema)
}

t, err := factory(path)
if err != nil {
return nil, fmt.Errorf("failed to create translator: %w", err)
if err = c.selectTranslator(b.Schema); err != nil {
return nil, fmt.Errorf("failed to select translator: %w", err)
}
p.translator = t

headerMutation, bodyMutation, override, err := p.translator.RequestBody(body)
headerMutation, bodyMutation, override, err := c.translator.RequestBody(body)
if err != nil {
return nil, fmt.Errorf("failed to transform request: %w", err)
}
@@ -83,13 +93,13 @@ func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBod
}
// Set the model name to the request header with the key `x-ai-gateway-llm-model-name`.
headerMutation.SetHeaders = append(headerMutation.SetHeaders, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.modelNameHeaderKey, RawValue: []byte(model)},
Header: &corev3.HeaderValue{Key: c.config.modelNameHeaderKey, RawValue: []byte(model)},
}, &corev3.HeaderValueOption{
Header: &corev3.HeaderValue{Key: p.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
Header: &corev3.HeaderValue{Key: c.config.selectedBackendHeaderKey, RawValue: []byte(b.Name)},
})

if authHandler, ok := p.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(ctx, p.requestHeaders, headerMutation, bodyMutation); err != nil {
if authHandler, ok := c.config.backendAuthHandlers[b.Name]; ok {
if err := authHandler.Do(ctx, c.requestHeaders, headerMutation, bodyMutation); err != nil {
return nil, fmt.Errorf("failed to do auth request: %w", err)
}
}
@@ -110,19 +120,19 @@ func (p *chatCompletionProcessor) ProcessRequestBody(ctx context.Context, rawBod
}

// ProcessResponseHeaders implements [ProcessorIface.ProcessResponseHeaders].
func (p *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
p.responseHeaders = headersToMap(headers)
if enc := p.responseHeaders["content-encoding"]; enc != "" {
p.responseEncoding = enc
func (c *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, headers *corev3.HeaderMap) (res *extprocv3.ProcessingResponse, err error) {
c.responseHeaders = headersToMap(headers)
if enc := c.responseHeaders["content-encoding"]; enc != "" {
c.responseEncoding = enc
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
if c.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseHeaders{
ResponseHeaders: &extprocv3.HeadersResponse{},
}}, nil
}
headerMutation, err := p.translator.ResponseHeaders(p.responseHeaders)
headerMutation, err := c.translator.ResponseHeaders(c.responseHeaders)
if err != nil {
return nil, fmt.Errorf("failed to transform response headers: %w", err)
}
@@ -134,9 +144,9 @@ func (p *chatCompletionProcessor) ProcessResponseHeaders(_ context.Context, head
}

// ProcessResponseBody implements [ProcessorIface.ProcessResponseBody].
func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
func (c *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *extprocv3.HttpBody) (res *extprocv3.ProcessingResponse, err error) {
var br io.Reader
switch p.responseEncoding {
switch c.responseEncoding {
case "gzip":
br, err = gzip.NewReader(bytes.NewReader(body.Body))
if err != nil {
@@ -147,11 +157,11 @@ func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
}
// The translator can be nil as there could be response event generated by previous ext proc without
// getting the request event.
if p.translator == nil {
if c.translator == nil {
return &extprocv3.ProcessingResponse{Response: &extprocv3.ProcessingResponse_ResponseBody{}}, nil
}

headerMutation, bodyMutation, tokenUsage, err := p.translator.ResponseBody(p.responseHeaders, br, body.EndOfStream)
headerMutation, bodyMutation, tokenUsage, err := c.translator.ResponseBody(c.responseHeaders, br, body.EndOfStream)
if err != nil {
return nil, fmt.Errorf("failed to transform response: %w", err)
}
@@ -168,55 +178,55 @@ func (p *chatCompletionProcessor) ProcessResponseBody(_ context.Context, body *e
}

// TODO: this is coupled with "LLM" specific logic. Once we have another use case, we need to refactor this.
p.costs.InputTokens += tokenUsage.InputTokens
p.costs.OutputTokens += tokenUsage.OutputTokens
p.costs.TotalTokens += tokenUsage.TotalTokens
if body.EndOfStream && len(p.config.requestCosts) > 0 {
resp.DynamicMetadata, err = p.maybeBuildDynamicMetadata()
c.costs.InputTokens += tokenUsage.InputTokens
c.costs.OutputTokens += tokenUsage.OutputTokens
c.costs.TotalTokens += tokenUsage.TotalTokens
if body.EndOfStream && len(c.config.requestCosts) > 0 {
resp.DynamicMetadata, err = c.maybeBuildDynamicMetadata()
if err != nil {
return nil, fmt.Errorf("failed to build dynamic metadata: %w", err)
}
}
return resp, nil
}

func (p *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(p.config.requestCosts))
for i := range p.config.requestCosts {
c := &p.config.requestCosts[i]
func (c *chatCompletionProcessor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
metadata := make(map[string]*structpb.Value, len(c.config.requestCosts))
for i := range c.config.requestCosts {
rc := &c.config.requestCosts[i]
var cost uint32
switch c.Type {
switch rc.Type {
case filterapi.LLMRequestCostTypeInputToken:
cost = p.costs.InputTokens
cost = c.costs.InputTokens
case filterapi.LLMRequestCostTypeOutputToken:
cost = p.costs.OutputTokens
cost = c.costs.OutputTokens
case filterapi.LLMRequestCostTypeTotalToken:
cost = p.costs.TotalTokens
cost = c.costs.TotalTokens
case filterapi.LLMRequestCostTypeCELExpression:
costU64, err := llmcostcel.EvaluateProgram(
c.celProg,
p.requestHeaders[p.config.modelNameHeaderKey],
p.requestHeaders[p.config.selectedBackendHeaderKey],
p.costs.InputTokens,
p.costs.OutputTokens,
p.costs.TotalTokens,
rc.celProg,
c.requestHeaders[c.config.modelNameHeaderKey],
c.requestHeaders[c.config.selectedBackendHeaderKey],
c.costs.InputTokens,
c.costs.OutputTokens,
c.costs.TotalTokens,
)
if err != nil {
return nil, fmt.Errorf("failed to evaluate CEL expression: %w", err)
}
cost = uint32(costU64) //nolint:gosec
default:
return nil, fmt.Errorf("unknown request cost kind: %s", c.Type)
return nil, fmt.Errorf("unknown request cost kind: %s", rc.Type)
}
p.logger.Info("Setting request cost metadata", "type", c.Type, "cost", cost, "metadataKey", c.MetadataKey)
metadata[c.MetadataKey] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(cost)}}
c.logger.Info("Setting request cost metadata", "type", rc.Type, "cost", cost, "metadataKey", rc.MetadataKey)
metadata[rc.MetadataKey] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: float64(cost)}}
}
if len(metadata) == 0 {
return nil, nil
}
return &structpb.Struct{
Fields: map[string]*structpb.Value{
p.config.metadataNamespace: {
c.config.metadataNamespace: {
Kind: &structpb.Value_StructValue{
StructValue: &structpb.Struct{Fields: metadata},
},
51 changes: 22 additions & 29 deletions internal/extproc/chatcompletion_processor_test.go
@@ -16,6 +16,24 @@ import (
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

func TestChatCompletion_SelectTranslator(t *testing.T) {
c := &chatCompletionProcessor{}
t.Run("unsupported", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: "Bar", Version: "v123"})
require.ErrorContains(t, err, "unsupported API schema: backend={Bar v123}")
})
t.Run("supported openai", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI})
require.NoError(t, err)
require.NotNil(t, c.translator)
})
t.Run("supported aws bedrock", func(t *testing.T) {
err := c.selectTranslator(filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock})
require.NoError(t, err)
require.NotNil(t, c.translator)
})
}

func TestChatCompletion_ProcessRequestHeaders(t *testing.T) {
p := &chatCompletionProcessor{}
res, err := p.ProcessRequestHeaders(t.Context(), &corev3.HeaderMap{
@@ -128,27 +146,9 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: make(map[filterapi.VersionedAPISchema]translator.Factory),
}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to find factory for output schema {\"some-schema\" \"v10.0\"}")
})
t.Run("translator factory error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
rbp := mockRequestBodyParser{t: t, retModelName: "some-model", expPath: "/foo"}
rt := mockRouter{
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retErr: errors.New("test error"), expPath: "/foo"}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
}, requestHeaders: headers, logger: slog.Default()}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to create translator: test error")
require.ErrorContains(t, err, "unsupported API schema: backend={some-schema v10.0}")
})
t.Run("translator error", func(t *testing.T) {
headers := map[string]string{":path": "/foo"}
@@ -157,13 +157,10 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
t: t, expHeaders: headers, retBackendName: "some-backend",
retVersionedAPISchema: filterapi.VersionedAPISchema{Name: "some-schema", Version: "v10.0"},
}
factory := mockTranslatorFactory{t: t, retTranslator: mockTranslator{t: t, retErr: errors.New("test error")}, expPath: "/foo"}
tr := mockTranslator{t: t, retErr: errors.New("test error")}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
}, requestHeaders: headers, logger: slog.Default()}
}, requestHeaders: headers, logger: slog.Default(), translator: tr}
_, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.ErrorContains(t, err, "failed to transform request: test error")
})
@@ -178,15 +175,11 @@ func TestChatCompletion_ProcessRequestBody(t *testing.T) {
headerMut := &extprocv3.HeaderMutation{}
bodyMut := &extprocv3.BodyMutation{}
mt := mockTranslator{t: t, expRequestBody: someBody, retHeaderMutation: headerMut, retBodyMutation: bodyMut}
factory := mockTranslatorFactory{t: t, retTranslator: mt, expPath: "/foo"}
p := &chatCompletionProcessor{config: &processorConfig{
bodyParser: rbp.impl, router: rt,
factories: map[filterapi.VersionedAPISchema]translator.Factory{
{Name: "some-schema", Version: "v10.0"}: factory.impl,
},
selectedBackendHeaderKey: "x-ai-gateway-backend-key",
modelNameHeaderKey: "x-ai-gateway-model-key",
}, requestHeaders: headers, logger: slog.Default()}
}, requestHeaders: headers, logger: slog.Default(), translator: mt}
resp, err := p.ProcessRequestBody(t.Context(), &extprocv3.HttpBody{})
require.NoError(t, err)
require.Equal(t, mt, p.translator)
14 changes: 0 additions & 14 deletions internal/extproc/mocks_test.go
@@ -139,20 +139,6 @@ func (m *mockRequestBodyParser) impl(path string, body *extprocv3.HttpBody) (mod
return m.retModelName, m.retRb, m.retErr
}

// mockTranslatorFactory implements [translator.Factory] for testing.
type mockTranslatorFactory struct {
t *testing.T
expPath string
retTranslator translator.Translator
retErr error
}

// NewTranslator implements [translator.Factory].
func (m mockTranslatorFactory) impl(path string) (translator.Translator, error) {
require.Equal(m.t, m.expPath, path)
return m.retTranslator, m.retErr
}

// mockExternalProcessingStream implements [extprocv3.ExternalProcessor_ProcessServer] for testing.
type mockExternalProcessingStream struct {
t *testing.T
2 changes: 0 additions & 2 deletions internal/extproc/processor.go
@@ -12,7 +12,6 @@ import (
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
)

// processorConfig is the configuration for the processor.
@@ -22,7 +21,6 @@ type processorConfig struct {
bodyParser router.RequestBodyParser
router x.Router
modelNameHeaderKey, selectedBackendHeaderKey string
factories map[filterapi.VersionedAPISchema]translator.Factory
backendAuthHandlers map[string]backendauth.Handler
metadataNamespace string
requestCosts []processorConfigRequestCost
10 changes: 0 additions & 10 deletions internal/extproc/server.go
@@ -22,7 +22,6 @@ import (
"github.com/envoyproxy/ai-gateway/filterapi/x"
"github.com/envoyproxy/ai-gateway/internal/extproc/backendauth"
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
"github.com/envoyproxy/ai-gateway/internal/extproc/translator"
"github.com/envoyproxy/ai-gateway/internal/llmcostcel"
)

@@ -60,19 +59,11 @@ func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error
}

var (
factories = make(map[filterapi.VersionedAPISchema]translator.Factory)
backendAuthHandlers = make(map[string]backendauth.Handler)
declaredModels []string
)
for _, r := range config.Rules {
for _, b := range r.Backends {
if _, ok := factories[b.Schema]; !ok {
factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema)
if err != nil {
return fmt.Errorf("cannot create translator factory: %w", err)
}
}

if b.Auth != nil {
backendAuthHandlers[b.Name], err = backendauth.NewHandler(ctx, b.Auth)
if err != nil {
Expand Down Expand Up @@ -112,7 +103,6 @@ func (s *Server) LoadConfig(ctx context.Context, config *filterapi.Config) error
bodyParser: bodyParser, router: rt,
selectedBackendHeaderKey: config.SelectedBackendHeaderKey,
modelNameHeaderKey: config.ModelNameHeaderKey,
factories: factories,
backendAuthHandlers: backendAuthHandlers,
metadataNamespace: config.MetadataNamespace,
requestCosts: costs,
3 changes: 0 additions & 3 deletions internal/extproc/server_test.go
@@ -86,9 +86,6 @@ func TestServer_LoadConfig(t *testing.T) {
require.NotNil(t, s.config.bodyParser)
require.Equal(t, "x-ai-eg-selected-backend", s.config.selectedBackendHeaderKey)
require.Equal(t, "x-model-name", s.config.modelNameHeaderKey)
require.Len(t, s.config.factories, 2)
require.NotNil(t, s.config.factories[filterapi.VersionedAPISchema{Name: filterapi.APISchemaOpenAI}])
require.NotNil(t, s.config.factories[filterapi.VersionedAPISchema{Name: filterapi.APISchemaAWSBedrock}])

require.Len(t, s.config.requestCosts, 2)
require.Equal(t, filterapi.LLMRequestCostTypeOutputToken, s.config.requestCosts[0].Type)
