From fe8a957946b4eed2d9d2dd6be1e6cc519f4133e5 Mon Sep 17 00:00:00 2001 From: Jacob Cable <32874567+cabljac@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:08:58 +0100 Subject: [PATCH] feat(js): add support for streaming json output (#484) * feat(js): add support for streaming json output * refactor: switch to partial-json library * refactor: merge the two extract methods * chore: update lockfile --- CONTRIBUTING.md | 2 +- js/ai/package.json | 4 +- js/ai/src/extract.ts | 36 +++++++++++-- js/ai/src/generate.ts | 37 +++++++++++-- js/ai/tests/generate/generate_test.ts | 75 +++++++++++++++++++++++++++ js/pnpm-lock.yaml | 3 ++ 6 files changed, 146 insertions(+), 11 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 559e302ea..4ca7fba8e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -131,7 +131,7 @@ genkit eval:flow pdfQA '"What's a brief description of MapReduce?"' FYI: `js` and `genkit-tools` are in two separate workspaces. -As you make changes you may want to build an test things by running test apps. +As you make changes you may want to build and test things by running test apps. 
You can reduce the scope of what you're building by running a specific build command: ``` diff --git a/js/ai/package.json b/js/ai/package.json index b186dbec4..611b986c6 100644 --- a/js/ai/package.json +++ b/js/ai/package.json @@ -15,7 +15,8 @@ "build:clean": "rm -rf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "node --import tsx --test ./tests/**/*_test.ts" + "test": "node --import tsx --test ./tests/**/*_test.ts", + "test:single": "node --import tsx --test" }, "repository": { "type": "git", @@ -30,6 +31,7 @@ "@types/node": "^20.11.19", "json5": "^2.2.3", "node-fetch": "^3.3.2", + "partial-json": "^0.1.7", "zod": "^3.22.4" }, "devDependencies": { diff --git a/js/ai/src/extract.ts b/js/ai/src/extract.ts index 8996af148..402587ed2 100644 --- a/js/ai/src/extract.ts +++ b/js/ai/src/extract.ts @@ -15,11 +15,27 @@ */ import JSON5 from 'json5'; +import { Allow, parse } from 'partial-json'; + +export function parsePartialJson(jsonString: string): T { + return JSON5.parse(JSON.stringify(parse(jsonString, Allow.ALL))); +} /** * Extracts JSON from string with lenient parsing rules to improve likelihood of successful extraction. 
*/ -export function extractJson(text: string): T | null { +export function extractJson( + text: string, + throwOnBadJson?: true +): T; +export function extractJson( + text: string, + throwOnBadJson?: false +): T | null; +export function extractJson( + text: string, + throwOnBadJson?: boolean +): T | null { let openingChar: '{' | '[' | undefined; let closingChar: '}' | ']' | undefined; let startPos: number | undefined; @@ -48,11 +64,21 @@ export function extractJson(text: string): T | null { } if (startPos !== undefined && nestingCount > 0) { + // If an incomplete JSON structure is detected try { - return JSON5.parse(text.substring(startPos) + (closingChar || '')) as T; - } catch (e) { - throw new Error(`Invalid JSON extracted from model output: ${text}`); + // Parse the incomplete JSON structure using partial-json for lenient parsing + // Note: partial-json automatically handles adding the closing character + return parsePartialJson(text.substring(startPos)); + } catch { + // If parsing fails, throw only when the caller requested it via throwOnBadJson + if (throwOnBadJson) { + throw new Error(`Invalid JSON extracted from model output: ${text}`); + } + return null; // Otherwise return null because the partial JSON could not be parsed } } } - throw new Error(`No JSON object or array found in model output: ${text}`); + if (throwOnBadJson) { + throw new Error(`Invalid JSON extracted from model output: ${text}`); + } + return null; // Return null if no JSON structure is found } diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index 1bc28e2e7..ebda1c322 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -74,7 +74,7 @@ export class Message implements MessageData { * * @returns The structured output contained in the message. */ - output(): T | null { + output(): T { return this.data() || extractJson(this.text()); } @@ -360,11 +360,17 @@ export class GenerateResponseChunk content: Part[]; /** Custom model-specific data for this chunk. */ custom?: unknown; + /** Accumulated chunks for partial output extraction. 
*/ + accumulatedChunks?: GenerateResponseChunkData[]; - constructor(data: GenerateResponseChunkData) { + constructor( + data: GenerateResponseChunkData, + accumulatedChunks?: GenerateResponseChunkData[] + ) { this.index = data.index; this.content = data.content || []; this.custom = data.custom; + this.accumulatedChunks = accumulatedChunks; } /** @@ -402,6 +408,18 @@ export class GenerateResponseChunk ) as ToolRequestPart[]; } + /** + * Concatenates the text of all accumulated chunks and attempts to parse the (possibly incomplete) JSON it contains. + * @returns The parsed value, or null if there are no accumulated chunks or no JSON could be extracted. + */ + output(): T | null { + if (!this.accumulatedChunks) return null; + const accumulatedText = this.accumulatedChunks + .map((chunk) => chunk.content.map((part) => part.text || '').join('')) + .join(''); + return extractJson(accumulatedText, false); + } + toJSON(): GenerateResponseChunkData { return { index: this.index, content: this.content, custom: this.custom }; } @@ -586,6 +604,7 @@ export class NoValidCandidatesError extends GenkitError { * @param options The options for this generation request. * @returns The generated response based on the provided parameters. */ + export async function generate< O extends z.ZodTypeAny = z.ZodTypeAny, CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema, @@ -612,10 +631,20 @@ export async function generate< } const request = await toGenerateRequest(resolvedOptions); + + const accumulatedChunks: GenerateResponseChunkData[] = []; + const response = await runWithStreamingCallback( resolvedOptions.streamingCallback - ? (chunk: GenerateResponseChunkData) => - resolvedOptions.streamingCallback!(new GenerateResponseChunk(chunk)) + ? 
(chunk: GenerateResponseChunkData) => { + // Store accumulated chunk data + accumulatedChunks.push(chunk); + if (resolvedOptions.streamingCallback) { + resolvedOptions.streamingCallback!( + new GenerateResponseChunk(chunk, accumulatedChunks) + ); + } + } : undefined, async () => new GenerateResponse>(await model(request), request) ); diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index f8717f9f8..2388c401b 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -17,6 +17,7 @@ import assert from 'node:assert'; import { describe, it } from 'node:test'; import { z } from 'zod'; +import { GenerateResponseChunk } from '../../src/generate'; import { Candidate, GenerateOptions, @@ -24,6 +25,7 @@ import { Message, toGenerateRequest, } from '../../src/generate.js'; +import { GenerateResponseChunkData } from '../../src/model'; import { CandidateData, GenerateRequest, @@ -506,3 +508,76 @@ describe('toGenerateRequest', () => { }); } }); + +describe('GenerateResponseChunk', () => { + describe('#output()', () => { + const testCases = [ + { + should: 'parse ``` correctly', + accumulatedChunksTexts: ['```'], + correctJson: null, + }, + { + should: 'parse valid json correctly', + accumulatedChunksTexts: [`{"foo":"bar"}`], + correctJson: { foo: 'bar' }, + }, + { + should: 'if json invalid, return null', + accumulatedChunksTexts: [`invalid json`], + correctJson: null, + }, + { + should: 'handle missing closing brace', + accumulatedChunksTexts: [`{"foo":"bar"`], + correctJson: { foo: 'bar' }, + }, + { + should: 'handle missing closing bracket in nested object', + accumulatedChunksTexts: [`{"foo": {"bar": "baz"`], + correctJson: { foo: { bar: 'baz' } }, + }, + { + should: 'handle multiple chunks', + accumulatedChunksTexts: [`{"foo": {"bar"`, `: "baz`], + correctJson: { foo: { bar: 'baz' } }, + }, + { + should: 'handle multiple chunks with nested objects', + accumulatedChunksTexts: [`\`\`\`json{"foo": 
{"bar"`, `: {"baz": "qux`], + correctJson: { foo: { bar: { baz: 'qux' } } }, + }, + { + should: 'handle array nested in object', + accumulatedChunksTexts: [`{"foo": ["bar`], + correctJson: { foo: ['bar'] }, + }, + { + should: 'handle array nested in object with multiple chunks', + accumulatedChunksTexts: [`\`\`\`json{"foo": {"bar"`, `: ["baz`], + correctJson: { foo: { bar: ['baz'] } }, + }, + ]; + + for (const test of testCases) { + if (test.should) { + it(test.should, () => { + const accumulatedChunks: GenerateResponseChunkData[] = + test.accumulatedChunksTexts.map((text, index) => ({ + index, + content: [{ text }], + })); + + const chunkData = accumulatedChunks[accumulatedChunks.length - 1]; + + const responseChunk: GenerateResponseChunk = + new GenerateResponseChunk(chunkData, accumulatedChunks); + + const output = responseChunk.output(); + + assert.deepStrictEqual(output, test.correctJson); + }); + } + } + }); +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 52ea1b575..c3224e83a 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -35,6 +35,9 @@ importers: node-fetch: specifier: ^3.3.2 version: 3.3.2 + partial-json: + specifier: ^0.1.7 + version: 0.1.7 zod: specifier: ^3.22.4 version: 3.22.4