diff --git a/typescript/src/lmp/_track.ts b/typescript/src/lmp/_track.ts index 9baea1378..a2ad8d136 100644 --- a/typescript/src/lmp/_track.ts +++ b/typescript/src/lmp/_track.ts @@ -46,7 +46,7 @@ export type Invocation = { -type F = (...args: any[]) => Promise> +type F = (...args: any[]) => Promise> | string | Array /** * Used for tracing of invocations. @@ -303,7 +303,7 @@ export const invokeWithTracking = async (lmp: LMPDefinition & { lmpId: string }, } const start = performance.now() - const lmpfnoutput = await f(...args) + const lmpfnoutput = await Promise.resolve(f(...args)) // const event = await getNextPausedEvent() // console.log('event', event) // await handleBreakpointHit(event) diff --git a/typescript/src/lmp/simple.ts b/typescript/src/lmp/simple.ts index fb2d5dcb0..f0f4a37d5 100644 --- a/typescript/src/lmp/simple.ts +++ b/typescript/src/lmp/simple.ts @@ -10,7 +10,7 @@ import { EllCallParams } from '../provider' const logger = logging.getLogger('ell') -type SimpleLMPInner = (...args: any[]) => Promise> +type SimpleLMPInner = (...args: any[]) => string | Array | Promise> type SimpleLMP = ((...args: Parameters) => Promise) & { __ell_type__?: 'simple' __ell_lmp_name__?: string @@ -49,7 +49,7 @@ export const simple = (a: Kwargs, f: F): SimpleLMP if (lmpId && !a.exempt_from_tracking) { return await invokeWithTracking({ ...lmpDefinition!, lmpId }, args, f, a) } - const promptFnOutput = await f(...args) + const promptFnOutput = await Promise.resolve(f(...args)) const modelClient = await getModelClient(a) const provider = config.getProviderFor(modelClient) if (!provider) { diff --git a/typescript/src/serialize/sql.ts b/typescript/src/serialize/sql.ts index 1f5206cda..84f56292f 100644 --- a/typescript/src/serialize/sql.ts +++ b/typescript/src/serialize/sql.ts @@ -78,6 +78,7 @@ export type InvocationContents = { free_vars: Record is_external: boolean invocation: Invocation + // todo. created_at? } export const InvocationContents = (props: InvocationContents) => ({ ...props, @@ -138,7 +139,7 @@ class Mutex { } export class SQLiteStore extends Store { - private db: Database | null = null + public db: Database | null = null private dbPath: string private txMutex = new Mutex() diff --git a/typescript/test/fixtures/hello_world.ts b/typescript/test/fixtures/hello_world.ts index e83c6999c..57b088704 100644 --- a/typescript/test/fixtures/hello_world.ts +++ b/typescript/test/fixtures/hello_world.ts @@ -14,7 +14,7 @@ function getRandomPunctuation(): string { return randomChoice(['!', '!!', '!!!']) } -export const hello = ell.simple({ model: 'gpt-4o-mini' }, async (name: string) => { +export const hello = ell.simple({ model: 'gpt-4o-mini' }, (name: string) => { const adjective = getRandomAdjective() const punctuation = getRandomPunctuation() diff --git a/typescript/test/runtime.mocha.ts b/typescript/test/runtime.mocha.ts index 4e359e938..f4b7740b2 100644 --- a/typescript/test/runtime.mocha.ts +++ b/typescript/test/runtime.mocha.ts @@ -68,4 +68,24 @@ describe('lmp', () => { const result = await child2('world') assert.deepStrictEqual(result, new Message('assistant', 'child')) }) + + test('sync prompt functions', async () => { + const child = simple({ model: 'gpt-4o-mini' }, (a: string) => { + return 'child' + }) + const hello = simple({ model: 'gpt-4o' }, async (a: { a: string }) => { + const ok = await child(a.a) + return a.a + ok + }) + + const result = await hello({ a: 'world' }) + + assert.equal(result, 'worldchild') + + assert.ok(hello.__ell_lmp_id__?.startsWith('lmp-')) + assert.equal(hello.__ell_lmp_name__, 'hello') + + assert.ok(child.__ell_lmp_id__?.startsWith('lmp-')) + assert.equal(child.__ell_lmp_name__, 'child') + }) }) diff --git a/typescript/test/tracing.mocha.ts b/typescript/test/tracing.mocha.ts new file mode 100644 index 000000000..1dddd494f --- /dev/null +++ b/typescript/test/tracing.mocha.ts @@ -0,0 +1,115 @@ +import assert from 'assert' +import OpenAI from 'openai' +import { config } from '../src/configurator' +import { chatCompletionsToStream } from './util' +import { SQLiteStore } from '../src/serialize/sql' +import * as ell from 'ell-ai' + +describe('tracing', () => { + let store: SQLiteStore + beforeEach(async () => { + store = new SQLiteStore(':memory:') + await store.initialize() + ell.init({ store }) + + config.defaultClient = config.defaultClient || new OpenAI({ apiKey: 'test' }) + // @ts-expect-error + config.defaultClient.chat.completions.create = async (...args) => { + return chatCompletionsToStream([ + { + usage: { + prompt_tokens: 10, + completion_tokens: 10, + latency_ms: 10, + total_tokens: 20, + }, + id: 'chatcmpl-123', + created: 1677652288, + model: 'gpt-3.5-turbo-0125', + object: 'chat.completion', + choices: [ + { + index: 0, + finish_reason: 'stop', + logprobs: null, + message: { + // @ts-expect-error + content: args[0].messages[0].content[0].text, + role: 'assistant', + refusal: null, + }, + }, + ], + }, + ]) + } + }) + + it('simple', async () => { + const hello = require('./fixtures/hello_world') + const result = await hello.hello('world') + + assert.equal(result, 'You are a helpful and expressive assistant.') + + const lmp = (await store.db?.all('SELECT * FROM serializedlmp'))?.[0] + assert.ok(typeof lmp.created_at === 'string') + delete lmp.created_at + assert.deepEqual(lmp, { + lmp_id: 'lmp-a79d4140040f36d6c8074901fd00d769', + name: 'test.fixtures.hello_world.hello', + source: + 'export const hello = ell.simple({ model: \'gpt-4o-mini\' }, (name: string) => {\n const adjective = getRandomAdjective()\n const punctuation = getRandomPunctuation()\n\n return [\n ell.system(\'You are a helpful and expressive assistant.\'),\n ell.user(`Say a ${adjective} hello to ${name}${punctuation}`),\n ] \n})', + language: 'typescript', + dependencies: '', + lmp_type: 'LM', + api_params: '{"model":"gpt-4o-mini"}', + initial_free_vars: '{}', + initial_global_vars: '{}', + num_invocations: 1, + commit_message: 'Initial version', + version_number: 1, + }) + + const invocation = (await store.db?.all('SELECT * FROM invocation'))?.[0] + const invocationId = invocation.id + + assert.ok(invocationId.startsWith('invocation-')) + delete invocation.id + assert.ok(typeof invocation.created_at === 'string') + delete invocation.created_at + assert.ok(typeof invocation.latency_ms === 'number') + delete invocation.latency_ms + + assert.deepStrictEqual(invocation, { + lmp_id: lmp.lmp_id, + prompt_tokens: null, + completion_tokens: null, + state_cache_key: '', + used_by_id: null, + }) + + const invocationContents = (await store.db?.all('SELECT * FROM invocationcontents'))?.[0] + + assert.equal(invocationContents?.invocation_id, invocationId) + delete invocationContents.invocation_id + + // Free vars + const freeVars = JSON.parse(invocationContents.free_vars) + assert.deepEqual(freeVars.name, 'world') + assert.ok(['enthusiastic', 'cheerful', 'warm', 'friendly', 'heartfelt', 'sincere'].includes(freeVars.adjective)) + assert.ok(['!', '!!', '!!!'].includes(freeVars.punctuation)) + delete invocationContents.free_vars + + // Global vars + const globalVars = JSON.parse(invocationContents.global_vars) + assert.deepEqual(globalVars, {}) + delete invocationContents.global_vars + + assert.deepStrictEqual(invocationContents, { + 'invocation_api_params': '{"model":"gpt-4o-mini"}', + 'is_external': 0, + 'params': '["world"]', + 'results': '"You are a helpful and expressive assistant."', + }) + }) +})