From 9a9ba19b1909d68b4c22fd3f47305a316301882e Mon Sep 17 00:00:00 2001 From: Jeason Date: Thu, 4 Jul 2024 21:29:35 +0800 Subject: [PATCH 1/2] feat: add embedding model for client side embed --- package.json | 7 +- pnpm-lock.yaml | 135 ++++++++++++++---------------------- src/embedding-model.test.ts | 53 ++++++++++++++ src/embedding-model.ts | 97 ++++++++++++++++++++++++++ src/index.ts | 1 + src/language-model.test.ts | 2 +- src/language-model.ts | 3 + 7 files changed, 212 insertions(+), 86 deletions(-) create mode 100644 src/embedding-model.test.ts create mode 100644 src/embedding-model.ts diff --git a/package.json b/package.json index 54b95c1..a66140b 100644 --- a/package.json +++ b/package.json @@ -35,7 +35,8 @@ "access": "public" }, "dependencies": { - "@ai-sdk/provider": "^0.0.10", + "@ai-sdk/provider": "^0.0.11", + "@mediapipe/tasks-text": "^0.10.14", "debug": "^4.3.5" }, "devDependencies": { @@ -46,8 +47,8 @@ "@radix-ui/react-dropdown-menu": "^2.1.0", "@radix-ui/react-label": "^2.1.0", "@radix-ui/react-select": "^2.1.0", - "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-slider": "^1.2.0", + "@radix-ui/react-slot": "^1.1.0", "@radix-ui/react-tooltip": "^1.1.2", "@tailwindcss/typography": "^0.5.13", "@types/debug": "^4.1.12", @@ -55,7 +56,7 @@ "@types/react": "^18", "@types/react-dom": "^18", "@vitest/coverage-v8": "^1.6.0", - "ai": "^3.1.31", + "ai": "^3.2.16", "autoprefixer": "^10.4.19", "class-variance-authority": "^0.7.0", "clsx": "^2.1.1", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 85c848c..f5d8312 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -6,8 +6,11 @@ settings: dependencies: '@ai-sdk/provider': - specifier: ^0.0.10 - version: 0.0.10 + specifier: ^0.0.11 + version: 0.0.11 + '@mediapipe/tasks-text': + specifier: ^0.10.14 + version: 0.10.14 debug: specifier: ^4.3.5 version: 4.3.5 @@ -62,8 +65,8 @@ devDependencies: specifier: ^1.6.0 version: 1.6.0(vitest@1.6.0) ai: - specifier: ^3.1.31 - version: 
3.1.31(react@18.3.1)(solid-js@1.8.17)(svelte@4.2.18)(vue@3.4.27)(zod@3.23.8) + specifier: ^3.2.16 + version: 3.2.16(react@18.3.1)(svelte@4.2.18)(vue@3.4.27)(zod@3.23.8) autoprefixer: specifier: ^10.4.19 version: 10.4.19(postcss@8.4.38) @@ -142,8 +145,8 @@ devDependencies: packages: - /@ai-sdk/provider-utils@0.0.13(zod@3.23.8): - resolution: {integrity: sha512-cB2dPm9flj+yin5sjBLFcXdW8sZtAXLE/OLKgz9uHpHM55s7mnwZrDGfO6ot/ukHTxDDJunZLW7qSjgK/u0F1g==} + /@ai-sdk/provider-utils@1.0.0(zod@3.23.8): + resolution: {integrity: sha512-Akq7MmGQII8xAuoVjJns/n/2BTUrF6qaXIj/3nEuXk/hPSdETlLWRSrjrTmLpte1VIPE5ecNzTALST+6nz47UQ==} engines: {node: '>=18'} peerDependencies: zod: ^3.0.0 @@ -151,38 +154,40 @@ packages: zod: optional: true dependencies: - '@ai-sdk/provider': 0.0.10 + '@ai-sdk/provider': 0.0.11 eventsource-parser: 1.1.2 nanoid: 3.3.6 secure-json-parse: 2.7.0 zod: 3.23.8 dev: true - /@ai-sdk/provider@0.0.10: - resolution: {integrity: sha512-NzkrtREQpHID1cTqY/C4CI30PVOaXWKYytDR2EcytmFgnP7Z6+CrGIA/YCnNhYAuUm6Nx+nGpRL/Hmyrv7NYzg==} + /@ai-sdk/provider@0.0.11: + resolution: {integrity: sha512-VTipPQ92Moa5Ovg/nZIc8yNoIFfukZjUHZcQMduJbiUh3CLQyrBAKTEV9AwjPy8wgVxj3+GZjon0yyOJKhfp5g==} engines: {node: '>=18'} dependencies: json-schema: 0.4.0 - /@ai-sdk/react@0.0.1(react@18.3.1)(zod@3.23.8): - resolution: {integrity: sha512-y6KXzxRR7vmAgDVnS/hnLPt3RztvWOisANBw47O1o1D2nDeUqTo8E/SNw2J8mzzlRInGaw40EREY8jEf9AcwWQ==} + /@ai-sdk/react@0.0.16(react@18.3.1)(zod@3.23.8): + resolution: {integrity: sha512-PUPjI4XB8or2m2NvRU8SBzGfSwjlJ19Mdde8LkeppFoNj++53kgM4BiniAsVRl8v8WNGZ55rrsLyY5g8h+gfeA==} engines: {node: '>=18'} peerDependencies: react: ^18 || ^19 + zod: ^3.0.0 peerDependenciesMeta: react: optional: true + zod: + optional: true dependencies: - '@ai-sdk/provider-utils': 0.0.13(zod@3.23.8) - '@ai-sdk/ui-utils': 0.0.1(zod@3.23.8) + '@ai-sdk/provider-utils': 1.0.0(zod@3.23.8) + '@ai-sdk/ui-utils': 0.0.9(zod@3.23.8) react: 18.3.1 swr: 2.2.0(react@18.3.1) - transitivePeerDependencies: - - zod + 
zod: 3.23.8 dev: true - /@ai-sdk/solid@0.0.1(solid-js@1.8.17)(zod@3.23.8): - resolution: {integrity: sha512-5WWdoqpemYW66rMZUYF4sbDtZfF96Vt8RtrzpLv+95ZUM1nY1elxAWpHCeOyYEjWJE5+eiKpUs6Jr5mP2/gz8Q==} + /@ai-sdk/solid@0.0.11(zod@3.23.8): + resolution: {integrity: sha512-8L4YoNNmDWmdnqtKnFdmaDZ+bIf1m160NXSPMEDhhWvp+t1SGMS/eLemuYEkDnlO18hhM/0IKX8lbQEyz7QYPQ==} engines: {node: '>=18'} peerDependencies: solid-js: ^1.7.7 @@ -190,16 +195,13 @@ packages: solid-js: optional: true dependencies: - '@ai-sdk/ui-utils': 0.0.1(zod@3.23.8) - solid-js: 1.8.17 - solid-swr-store: 0.10.7(solid-js@1.8.17)(swr-store@0.10.6) - swr-store: 0.10.6 + '@ai-sdk/ui-utils': 0.0.9(zod@3.23.8) transitivePeerDependencies: - zod dev: true - /@ai-sdk/svelte@0.0.1(svelte@4.2.18)(zod@3.23.8): - resolution: {integrity: sha512-bpjTLKOwdcXjJzboq15etT1hdnRI1ErPZweWSsu1/LJlEFzD1M0qpZQwWHwPquYkzeppXOgsLrUZ+9D2RoC47Q==} + /@ai-sdk/svelte@0.0.12(svelte@4.2.18)(zod@3.23.8): + resolution: {integrity: sha512-pSgIQhu0H2MRuoi/oj/5sq7UIK7Nm1oLRmZQ0tz2iQeqO2uo3Pe0si4n7lo8gb8gOMCyqJtqteb13A7rAlusfQ==} engines: {node: '>=18'} peerDependencies: svelte: ^3.0.0 || ^4.0.0 @@ -207,25 +209,30 @@ packages: svelte: optional: true dependencies: - '@ai-sdk/provider-utils': 0.0.13(zod@3.23.8) - '@ai-sdk/ui-utils': 0.0.1(zod@3.23.8) + '@ai-sdk/provider-utils': 1.0.0(zod@3.23.8) + '@ai-sdk/ui-utils': 0.0.9(zod@3.23.8) sswr: 2.1.0(svelte@4.2.18) svelte: 4.2.18 transitivePeerDependencies: - zod dev: true - /@ai-sdk/ui-utils@0.0.1(zod@3.23.8): - resolution: {integrity: sha512-zOr1zIw/EH4fEQvDKsqYG3wY7GW32h8Wrx0lQpSAP59UCA4zgHBH6ogF5oj7+LUuWjT6be9S0G3l/tEPyRyxEw==} + /@ai-sdk/ui-utils@0.0.9(zod@3.23.8): + resolution: {integrity: sha512-RdC68yG1abpFQgpm3Tcn4hMbRzpRj0BXbphhwSpMwHqPQu4c/n82tYYJvhGB+rRXs/qLftLBS1NtrhqEYSVZTg==} engines: {node: '>=18'} + peerDependencies: + zod: ^3.0.0 + peerDependenciesMeta: + zod: + optional: true dependencies: - '@ai-sdk/provider-utils': 0.0.13(zod@3.23.8) - transitivePeerDependencies: - - zod + 
'@ai-sdk/provider-utils': 1.0.0(zod@3.23.8) + secure-json-parse: 2.7.0 + zod: 3.23.8 dev: true - /@ai-sdk/vue@0.0.1(vue@3.4.27)(zod@3.23.8): - resolution: {integrity: sha512-B3qAW22FYGy1ltobnF7LiPAmARTrCkH15qjw4WAXCnvRohsYOFTDACOBEsXRfa1OHmqWsUOYeNtE/oPhK3ybqw==} + /@ai-sdk/vue@0.0.11(vue@3.4.27)(zod@3.23.8): + resolution: {integrity: sha512-YXqrFCIo8iOCsTBagEAAH6YIgveZCvS66Lm+WcyYVC5ehwx4Hn2vSayaRUiqQiHxDkF/IdETURRKki/cGbp/eg==} engines: {node: '>=18'} peerDependencies: vue: ^3.3.4 @@ -233,7 +240,7 @@ packages: vue: optional: true dependencies: - '@ai-sdk/ui-utils': 0.0.1(zod@3.23.8) + '@ai-sdk/ui-utils': 0.0.9(zod@3.23.8) swrv: 1.0.4(vue@3.4.27) vue: 3.4.27(typescript@5.4.5) transitivePeerDependencies: @@ -1199,6 +1206,10 @@ packages: - supports-color dev: true + /@mediapipe/tasks-text@0.10.14: + resolution: {integrity: sha512-hQU/t9df83lU083sjbX69ImipPOqfvdMnc0pMk44L0QkRXKjSp9q1OYVhA5ePOe1bUmEUzaYSPPTwz+mNhFqhw==} + dev: false + /@next/env@14.2.3: resolution: {integrity: sha512-W7fd7IbkfmeeY2gXrzJYDx8D2lWKbVoTIj1o1ScPHNzvp30s1AuoEFSdr39bC5sjxJaxTtq3OTCZboNp0lNWHA==} dev: true @@ -2720,8 +2731,8 @@ packages: - supports-color dev: true - /ai@3.1.31(react@18.3.1)(solid-js@1.8.17)(svelte@4.2.18)(vue@3.4.27)(zod@3.23.8): - resolution: {integrity: sha512-fnQz8qlBuJuImUZCydbn0bTFCZFRwHeVQI+wBbBkR5S/FrF09snt0YrgiWzDc0il4u1rerzVPEUiasOdoGaoWA==} + /ai@3.2.16(react@18.3.1)(svelte@4.2.18)(vue@3.4.27)(zod@3.23.8): + resolution: {integrity: sha512-kNqnmSQxUm3dcxLv/NoOusoMAq7EK/zahB/N//wkGoawgYqfOUpvz3+uS/FG52tZwMyy4YblVoeuBHpfPWzNDw==} engines: {node: '>=18'} peerDependencies: openai: ^4.42.0 @@ -2738,13 +2749,13 @@ packages: zod: optional: true dependencies: - '@ai-sdk/provider': 0.0.10 - '@ai-sdk/provider-utils': 0.0.13(zod@3.23.8) - '@ai-sdk/react': 0.0.1(react@18.3.1)(zod@3.23.8) - '@ai-sdk/solid': 0.0.1(solid-js@1.8.17)(zod@3.23.8) - '@ai-sdk/svelte': 0.0.1(svelte@4.2.18)(zod@3.23.8) - '@ai-sdk/ui-utils': 0.0.1(zod@3.23.8) - '@ai-sdk/vue': 0.0.1(vue@3.4.27)(zod@3.23.8) 
+ '@ai-sdk/provider': 0.0.11 + '@ai-sdk/provider-utils': 1.0.0(zod@3.23.8) + '@ai-sdk/react': 0.0.16(react@18.3.1)(zod@3.23.8) + '@ai-sdk/solid': 0.0.11(zod@3.23.8) + '@ai-sdk/svelte': 0.0.12(svelte@4.2.18)(zod@3.23.8) + '@ai-sdk/ui-utils': 0.0.9(zod@3.23.8) + '@ai-sdk/vue': 0.0.11(vue@3.4.27)(zod@3.23.8) eventsource-parser: 1.1.2 json-schema: 0.4.0 jsondiffpatch: 0.6.0 @@ -7358,20 +7369,6 @@ packages: hasBin: true dev: true - /seroval-plugins@1.0.7(seroval@1.0.7): - resolution: {integrity: sha512-GO7TkWvodGp6buMEX9p7tNyIkbwlyuAWbI6G9Ec5bhcm7mQdu3JOK1IXbEUwb3FVzSc363GraG/wLW23NSavIw==} - engines: {node: '>=10'} - peerDependencies: - seroval: ^1.0 - dependencies: - seroval: 1.0.7 - dev: true - - /seroval@1.0.7: - resolution: {integrity: sha512-n6ZMQX5q0Vn19Zq7CIKNIo7E75gPkGCFUEqDpa8jgwpYr/vScjqnQ6H09t1uIiZ0ZSK0ypEGvrYK2bhBGWsGdw==} - engines: {node: '>=10'} - dev: true - /set-blocking@2.0.0: resolution: {integrity: sha512-KiKBS8AnWGEyLzofFfmvKwpdPzqiy16LvQfK3yv/fVH7Bj13/wl3JSR1J+rfgRE9q7xUJK4qvgS8raSOeLUehw==} dev: true @@ -7492,25 +7489,6 @@ packages: engines: {node: '>=8'} dev: true - /solid-js@1.8.17: - resolution: {integrity: sha512-E0FkUgv9sG/gEBWkHr/2XkBluHb1fkrHywUgA6o6XolPDCJ4g1HaLmQufcBBhiF36ee40q+HpG/vCZu7fLpI3Q==} - dependencies: - csstype: 3.1.3 - seroval: 1.0.7 - seroval-plugins: 1.0.7(seroval@1.0.7) - dev: true - - /solid-swr-store@0.10.7(solid-js@1.8.17)(swr-store@0.10.6): - resolution: {integrity: sha512-A6d68aJmRP471aWqKKPE2tpgOiR5fH4qXQNfKIec+Vap+MGQm3tvXlT8n0I8UgJSlNAsSAUuw2VTviH2h3Vv5g==} - engines: {node: '>=10'} - peerDependencies: - solid-js: ^1.2 - swr-store: ^0.10 - dependencies: - solid-js: 1.8.17 - swr-store: 0.10.6 - dev: true - /source-map-js@1.2.0: resolution: {integrity: sha512-itJW8lvSA0TXEphiRoawsCksnlf8SyvmFzIhltqAHluXd88pkCd+cXJVHTDwdCr0IzwptSm035IHQktUu1QUMg==} engines: {node: '>=0.10.0'} @@ -7785,13 +7763,6 @@ packages: periscopic: 3.1.0 dev: true - /swr-store@0.10.6: - resolution: {integrity: 
sha512-xPjB1hARSiRaNNlUQvWSVrG5SirCjk2TmaUyzzvk69SZQan9hCJqw/5rG9iL7xElHU784GxRPISClq4488/XVw==} - engines: {node: '>=10'} - dependencies: - dequal: 2.0.3 - dev: true - /swr@2.2.0(react@18.3.1): resolution: {integrity: sha512-AjqHOv2lAhkuUdIiBu9xbuettzAzWXmCEcLONNKJRba87WAefz8Ca9d6ds/SzrPc235n1IxWYdhJ2zF3MNUaoQ==} peerDependencies: diff --git a/src/embedding-model.test.ts b/src/embedding-model.test.ts new file mode 100644 index 0000000..3f9c7ca --- /dev/null +++ b/src/embedding-model.test.ts @@ -0,0 +1,53 @@ +import { describe, it, expect, vi } from 'vitest'; +import { ChromeAIEmbeddingModel, chromeEmbedding } from './embedding-model'; +import { embed } from 'ai'; + +vi.mock('@mediapipe/tasks-text', async () => ({ + FilesetResolver: { + forTextTasks: vi.fn(async () => ({ + wasmLoaderPath: 'wasmLoaderPath', + wasmBinaryPath: 'wasmBinaryPath', + })), + }, + TextEmbedder: { + createFromOptions: vi.fn(async () => ({ + embed: vi.fn((text: string) => ({ + embeddings: [ + { floatEmbedding: text === 'undefined' ? 
undefined : [1, 2] }, + ], + })), + })), + }, +})); + +describe('embedding-model', () => { + it('should instantiation anyways', async () => { + expect(new ChromeAIEmbeddingModel()).toBeInstanceOf(ChromeAIEmbeddingModel); + expect(chromeEmbedding()).toBeInstanceOf(ChromeAIEmbeddingModel); + }); + it('should embed', async () => { + const model = chromeEmbedding(); + expect( + await embed({ + model, + value: 'test', + }) + ).toMatchObject({ embedding: [1, 2] }); + + expect( + await embed({ + model, + value: 'test2', + }) + ).toMatchObject({ embedding: [1, 2] }); + }); + + it('should embed result empty', async () => { + expect( + await embed({ + model: chromeEmbedding({ l2Normalize: true }), + value: 'undefined', + }) + ).toMatchObject({ embedding: [] }); + }); +}); diff --git a/src/embedding-model.ts b/src/embedding-model.ts new file mode 100644 index 0000000..565d7f6 --- /dev/null +++ b/src/embedding-model.ts @@ -0,0 +1,97 @@ +import { EmbeddingModelV1, EmbeddingModelV1Embedding } from '@ai-sdk/provider'; +import { TextEmbedder, FilesetResolver } from '@mediapipe/tasks-text'; + +export interface ChromeAIEmbeddingModelSettings { + /** + * An optional base path to specify the directory the Wasm files should be loaded from. + * It's about 6mb before gzip. + * @default 'https://unpkg.com/@mediapipe/tasks-text/wasm/' + */ + filesetBasePath?: string; + /** + * The model path to the model asset file. + * It's about 6.1mb before gzip. + * @default 'https://storage.googleapis.com/mediapipe-models/text_embedder/universal_sentence_encoder/float32/1/universal_sentence_encoder.tflite' + */ + modelAssetPath?: string; + /** + * Whether to normalize the returned feature vector with L2 norm. Use this + * option only if the model does not already contain a native L2_NORMALIZATION + * TF Lite Op. In most cases, this is already the case and L2 norm is thus + * achieved through TF Lite inference. 
+ * @default false + */ + l2Normalize?: boolean; + /** + * Whether the returned embedding should be quantized to bytes via scalar + * quantization. Embeddings are implicitly assumed to be unit-norm and + * therefore any dimension is guaranteed to have a value in [-1.0, 1.0]. Use + * the l2_normalize option if this is not the case. + * @default false + */ + quantize?: boolean; + /** + * Overrides the default backend to use for the provided model. + */ + delegate?: 'CPU' | 'GPU'; +} + +export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> { + readonly specificationVersion = 'v1'; + readonly provider = 'google-mediapipe'; + readonly modelId: string = 'mediapipe'; + readonly supportsParallelCalls = true; + readonly maxEmbeddingsPerCall = undefined; + + private settings: ChromeAIEmbeddingModelSettings = { + filesetBasePath: 'https://unpkg.com/@mediapipe/tasks-text/wasm/', + modelAssetPath: + 'https://storage.googleapis.com/mediapipe-models/text_embedder/universal_sentence_encoder/float32/1/universal_sentence_encoder.tflite', + l2Normalize: false, + quantize: false, + }; + private textEmbedder: TextEmbedder | null = null; + + public constructor(settings: ChromeAIEmbeddingModelSettings = {}) { + this.settings = { ...this.settings, ...settings }; + } + + protected getTextEmbedder = async (): Promise<TextEmbedder> => { + if (this.textEmbedder !== null) return this.textEmbedder; + const textFiles = await FilesetResolver.forTextTasks( + this.settings.filesetBasePath + ); + this.textEmbedder = await TextEmbedder.createFromOptions(textFiles, { + baseOptions: { + modelAssetPath: this.settings.modelAssetPath, + delegate: this.settings.delegate, + }, + l2Normalize: this.settings.l2Normalize, + quantize: this.settings.quantize, + }); + return this.textEmbedder; + }; + + public doEmbed = async (options: { + values: string[]; + abortSignal?: AbortSignal; + }): Promise<{ + embeddings: Array<EmbeddingModelV1Embedding>; + rawResponse?: Record<PropertyKey, any>; + }> => { + // if (options.abortSignal) console.warn('abortSignal is not 
supported'); + + const embedder = await this.getTextEmbedder(); + const embeddings = await Promise.all( + options.values.map((text) => { + const embedderResult = embedder.embed(text); + const [embedding] = embedderResult.embeddings; + return embedding?.floatEmbedding ?? []; + }) + ); + return { embeddings }; + }; +} + +export const chromeEmbedding = (options?: ChromeAIEmbeddingModelSettings) => + new ChromeAIEmbeddingModel(options); diff --git a/src/index.ts b/src/index.ts index 6311830..b26e9d6 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1 +1,2 @@ export * from './language-model'; +export * from './embedding-model'; diff --git a/src/language-model.test.ts b/src/language-model.test.ts index 583a8bc..b396f86 100644 --- a/src/language-model.test.ts +++ b/src/language-model.test.ts @@ -7,7 +7,7 @@ import { } from '@ai-sdk/provider'; import { z } from 'zod'; -describe('chrome-ai', () => { +describe('language-model', () => { // Reset all stubs after each test afterEach(() => { vi.unstubAllGlobals(); diff --git a/src/language-model.ts b/src/language-model.ts index 07336e7..274f5b3 100644 --- a/src/language-model.ts +++ b/src/language-model.ts @@ -17,6 +17,7 @@ import { import { ChromeAISession, ChromeAISessionOptions } from './global'; import createDebug from 'debug'; import { StreamAI } from './stream-ai'; +import { chromeEmbedding } from './embedding-model'; const debug = createDebug('chromeai'); @@ -237,3 +238,5 @@ export const chromeai = ( modelId: ChromeAIChatModelId = 'generic', settings: ChromeAIChatSettings = {} ) => new ChromeAIChatLanguageModel(modelId, settings); + +chromeai.embedding = chromeEmbedding; From 864c1679c0eea0edbcf082e16b6f97eaff40e162 Mon Sep 17 00:00:00 2001 From: Jeason Date: Thu, 4 Jul 2024 21:30:40 +0800 Subject: [PATCH 2/2] chore: add changeset for embedding model --- .changeset/sweet-donuts-prove.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/sweet-donuts-prove.md diff --git 
a/.changeset/sweet-donuts-prove.md b/.changeset/sweet-donuts-prove.md new file mode 100644 index 0000000..b81931a --- /dev/null +++ b/.changeset/sweet-donuts-prove.md @@ -0,0 +1,5 @@ +--- +"chrome-ai": minor +--- + +feat: add embedding model for client side embed with ai sdk