Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for structured generation for vercel ai sdk #24

Merged
merged 2 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions src/instrumentation/vercel-sdk.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ const extractSettings = (options: Options<AllVercelFn>): ILLMSettings => {
])
);
}
if ('schema' in settings) {
settings.schema = zodToJsonSchema(settings.schema);
}
return settings;
};

Expand Down Expand Up @@ -224,27 +227,31 @@ const computeMetricsStream = async (
};
};

type ExtendedFunction<T extends (...args: any[]) => any> = (
options: Parameters<T>[0] & {
literalAiParent?: Step | Thread;
}
) => ReturnType<T>;

export const makeInstrumentVercelSDK = (client: LiteralClient) => {
function instrumentVercelSDK<T>(
fn: typeof streamObject<T>
): ExtendedFunction<typeof streamObject<T>>;
function instrumentVercelSDK<
TOOLS extends Record<string, CoreTool<any, any>>
>(fn: typeof streamText<TOOLS>): ExtendedFunction<typeof streamText<TOOLS>>;
function instrumentVercelSDK<T>(
fn: typeof generateObject<T>
): ExtendedFunction<typeof generateObject<T>>;
function instrumentVercelSDK<
TOOLS extends Record<string, CoreTool<any, any>>
>(
fn: typeof generateText<TOOLS>
): ExtendedFunction<typeof generateText<TOOLS>>;
type VercelExtraOptions = {
literalAiParent?: Step | Thread;
};

export type InstrumentationVercelMethod = {
(fn: typeof streamObject): <T>(
options: Parameters<typeof streamObject<T>>[0] & VercelExtraOptions
) => ReturnType<typeof streamObject<T>>;

(fn: typeof streamText): <TOOLS extends Record<string, CoreTool<any, any>>>(
options: Parameters<typeof streamText<TOOLS>>[0] & VercelExtraOptions
) => ReturnType<typeof streamText<TOOLS>>;

(fn: typeof generateObject): <T>(
options: Parameters<typeof generateObject<T>>[0] & VercelExtraOptions
) => ReturnType<typeof generateObject<T>>;

(fn: typeof generateText): <TOOLS extends Record<string, CoreTool<any, any>>>(
options: Parameters<typeof generateText<TOOLS>>[0] & VercelExtraOptions
) => ReturnType<typeof generateText<TOOLS>>;
};

export const makeInstrumentVercelSDK = (
client: LiteralClient
): InstrumentationVercelMethod => {
function instrumentVercelSDK<TFunction extends AllVercelFn>(fn: TFunction) {
type TOptions = Options<TFunction>;
type TResult = Result<TFunction>;
Expand Down
125 changes: 113 additions & 12 deletions tests/integration/vercel-sdk.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { openai } from '@ai-sdk/openai';
import { generateText, streamText } from 'ai';
import { generateObject, generateText, streamObject, streamText } from 'ai';
import { z } from 'zod';

import { LiteralClient } from '../../src';
Expand All @@ -20,21 +20,16 @@ describe('Vercel SDK Instrumentation', () => {

// Skip for the CI
describe.skip('With OpenAI', () => {
let model: ReturnType<typeof openai>;
beforeEach(() => {
model = openai('gpt-3.5-turbo');
});

afterEach(() => jest.restoreAllMocks());

it('should work a simple generation', async () => {
it('should work a simple text generation', async () => {
const spy = jest.spyOn(client.api, 'createGeneration');

const generateTextWithLiteralAI =
client.instrumentation.vercel.instrument(generateText);

const result = await generateTextWithLiteralAI({
model,
model: openai('gpt-3.5-turbo'),
prompt: 'Write a vegetarian lasagna recipe for 4 people.'
});

Expand Down Expand Up @@ -68,7 +63,7 @@ describe('Vercel SDK Instrumentation', () => {
client.instrumentation.vercel.instrument(streamText);

const result = await streamTextWithLiteralAI({
model,
model: openai('gpt-3.5-turbo'),
prompt: 'Write a strawberry tiramisu recipe for 4 people.'
});

Expand Down Expand Up @@ -101,6 +96,112 @@ describe('Vercel SDK Instrumentation', () => {
);
});

it('should work on structured generation', async () => {
const spy = jest.spyOn(client.api, 'createGeneration');

const generateObjectWithLiteralAI =
client.instrumentation.vercel.instrument(generateObject);

const result = await generateObjectWithLiteralAI({
model: openai('gpt-4'),
schema: z.object({
recipe: z.object({
name: z.string(),
ingredients: z.array(
z.object({
name: z.string(),
amount: z.string()
})
),
steps: z.array(z.string())
})
}),
prompt: 'Generate a carrot cake recipe.'
});

console.log({ result });

expect(result.object).toBeTruthy();

expect(spy).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'openai.chat',
model: 'gpt-4',
messages: [
{
role: 'user',
content: [
{
type: 'text',
text: 'Generate a carrot cake recipe.'
}
]
}
],
messageCompletion: {
role: 'assistant',
content: JSON.stringify(result.object)
},
duration: expect.any(Number)
})
);
});

it('should work for streamed structured generation', async () => {
const spy = jest.spyOn(client.api, 'createGeneration');

const streamObjectWithLiteralAI =
client.instrumentation.vercel.instrument(streamObject);

const result = await streamObjectWithLiteralAI({
model: openai('gpt-4'),
schema: z.object({
recipe: z.object({
name: z.string(),
ingredients: z.array(
z.object({
name: z.string(),
amount: z.string()
})
),
steps: z.array(z.string())
})
}),
prompt: 'Generate a cheese cake recipe.'
});

let lastObject;
// use partialObjectStream as an async iterable:
for await (const part of result.partialObjectStream) {
lastObject = part;
}

expect(lastObject).toBeTruthy();

expect(spy).toHaveBeenCalledWith(
expect.objectContaining({
provider: 'openai.chat',
model: 'gpt-4',
messages: [
{
role: 'user',
content: [
{
type: 'text',
text: 'Generate a cheese cake recipe.'
}
]
}
],
messageCompletion: {
role: 'assistant',
content: expect.any(String)
},
duration: expect.any(Number)
})
);
});

it('should observe on a given thread', async () => {
const spy = jest.spyOn(client.api, 'sendSteps');

Expand All @@ -110,7 +211,7 @@ describe('Vercel SDK Instrumentation', () => {
client.instrumentation.vercel.instrument(generateText);

const result = await generateTextWithLiteralAI({
model,
model: openai('gpt-3.5-turbo'),
prompt: 'Write a vegetarian lasagna recipe for 4 people.',
literalAiParent: thread
});
Expand Down Expand Up @@ -151,7 +252,7 @@ describe('Vercel SDK Instrumentation', () => {
client.instrumentation.vercel.instrument(generateText);

const { text, toolResults } = await generateTextWithLiteralAI({
model,
model: openai('gpt-3.5-turbo'),
system: 'You are a friendly assistant!',
messages: [{ role: 'user', content: 'Convert 20°C to Fahrenheit' }],
tools: {
Expand Down Expand Up @@ -239,7 +340,7 @@ describe('Vercel SDK Instrumentation', () => {
client.instrumentation.vercel.instrument(streamText);

const result = await streamTextWithLiteralAI({
model,
model: openai('gpt-3.5-turbo'),
system: 'You are a friendly assistant!',
messages: [{ role: 'user', content: 'Convert 20°C to Fahrenheit' }],
tools: {
Expand Down
Loading