Skip to content

Commit

Permalink
feat: Add support for tool calls with the Vercel AI SDK
Browse files Browse the repository at this point in the history
  • Loading branch information
Granipouss committed May 13, 2024
1 parent 569a63b commit 44c4268
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 15 deletions.
32 changes: 21 additions & 11 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@
"tsup": "^8.0.1",
"typedoc": "^0.25.13",
"typedoc-plugin-markdown": "^4.0.0-next.25",
"typescript": "^5.3.3"
"typescript": "^5.3.3",
"zod": "^3.23.8",
"zod-to-json-schema": "^3.23.0"
},
"dependencies": {
"axios": "^1.6.2",
Expand All @@ -57,6 +59,7 @@
"@ai-sdk/openai": "^0.0.9",
"ai": "^3.1.0",
"langchain": "^0.1.14",
"openai": "^4.26.0"
"openai": "^4.26.0",
"zod-to-json-schema": "^3.23.0"
}
}
34 changes: 32 additions & 2 deletions src/instrumentation/vercel-sdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import type {
streamObject,
streamText
} from 'ai';
import { zodToJsonSchema } from 'zod-to-json-schema';

import {
ChatGeneration,
Expand Down Expand Up @@ -61,6 +62,17 @@ const extractSettings = (options: Options<AllVercelFn>): ILLMSettings => {
delete settings.model;
delete settings.prompt;
delete settings.abortSignal;
if ('tools' in settings) {
settings.tools = Object.fromEntries(
Object.entries<CoreTool>(settings.tools).map(([key, tool]) => [
key,
{
description: tool.description,
parameters: zodToJsonSchema(tool.parameters)
}
])
);
}
return settings;
};

Expand All @@ -80,12 +92,20 @@ const computeMetricsSync = (
const completion =
'text' in result ? result.text : JSON.stringify(result.object);

const messageCompletion: IGenerationMessage = {
role: 'assistant',
content: completion
};
if ('toolCalls' in result) {
messageCompletion.tool_calls = result.toolResults;
}

return {
duration,
tokenThroughputInSeconds,
outputTokenCount,
inputTokenCount,
messageCompletion: { role: 'assistant', content: completion }
messageCompletion
};
};

Expand All @@ -111,7 +131,17 @@ const computeMetricsStream = async (
}
case 'tool-call':
case 'tool-result': {
// TODO: Handle
messageCompletion.tool_calls = messageCompletion.tool_calls ?? [];
const index = messageCompletion.tool_calls.findIndex(
(call) => call.toolCallId === chunk.toolCallId
);
if (index === -1) {
// Insert tool call
messageCompletion.tool_calls.push(chunk);
} else {
// Replace the tool call with the result
messageCompletion.tool_calls[index] = chunk;
}
break;
}
}
Expand Down
119 changes: 119 additions & 0 deletions tests/integration/vercel-sdk.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { openai } from '@ai-sdk/openai';
import { generateText, streamText } from 'ai';
import { z } from 'zod';

import { LiteralClient } from '../../src';

Expand Down Expand Up @@ -142,5 +143,123 @@ describe('Vercel SDK Instrumentation', () => {
})
]);
});

it('should monitor tools', async () => {
const spy = jest.spyOn(client.api, 'createGeneration');

const generateTextWithLiteralAI =
client.instrumentation.vercel.instrument(generateText);

const { text, toolResults } = await generateTextWithLiteralAI({
model,
system: 'You are a friendly assistant!',
messages: [{ role: 'user', content: 'Convert 20°C to Fahrenheit' }],
tools: {
celsiusToFahrenheit: {
description: 'Converts celsius to fahrenheit',
parameters: z.object({
value: z.number().describe('The value in celsius')
}),
execute: async ({ value }) => {
const celsius = parseFloat(value);
const fahrenheit = celsius * (9 / 5) + 32;
return fahrenheit;
}
}
}
});

expect(text).toBe('');
expect(toolResults).toEqual([
{
toolCallId: expect.any(String),
toolName: 'celsiusToFahrenheit',
args: { value: 20 },
result: 68
}
]);

expect(spy).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'openai.chat',
model: 'gpt-3.5-turbo',
messages: [
{ role: 'system', content: 'You are a friendly assistant!' },
{ role: 'user', content: 'Convert 20°C to Fahrenheit' }
],
messageCompletion: {
role: 'assistant',
content: '',
tool_calls: toolResults
},
duration: expect.any(Number)
})
);
});

it('should monitor tools in streams', async () => {
const spy = jest.spyOn(client.api, 'createGeneration');

const streamTextWithLiteralAI =
client.instrumentation.vercel.instrument(streamText);

const result = await streamTextWithLiteralAI({
model,
system: 'You are a friendly assistant!',
messages: [{ role: 'user', content: 'Convert 20°C to Fahrenheit' }],
tools: {
celsiusToFahrenheit: {
description: 'Converts celsius to fahrenheit',
parameters: z.object({
value: z.number().describe('The value in celsius')
}),
execute: async ({ value }) => {
const celsius = parseFloat(value);
const fahrenheit = celsius * (9 / 5) + 32;
return fahrenheit;
}
}
}
});

// use textStream as an async iterable:
const chunks = [];
let toolCall, toolResult;
for await (const chunk of result.fullStream) {
chunks.push(chunk);
if (chunk.type === 'tool-call') {
toolCall = chunk;
}
if (chunk.type === 'tool-result') {
toolResult = chunk;
}
}

expect(toolCall!.toolCallId).toEqual(toolResult!.toolCallId);
expect(toolResult).toEqual({
type: 'tool-result',
toolCallId: expect.any(String),
toolName: 'celsiusToFahrenheit',
args: { value: 20 },
result: 68
});

expect(spy).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'openai.chat',
model: 'gpt-3.5-turbo',
messages: [
{ role: 'system', content: 'You are a friendly assistant!' },
{ role: 'user', content: 'Convert 20°C to Fahrenheit' }
],
messageCompletion: {
role: 'assistant',
content: '',
tool_calls: [toolResult]
},
duration: expect.any(Number)
})
);
});
});
});

0 comments on commit 44c4268

Please sign in to comment.