Skip to content

Commit

Permalink
fix(provider-transformers|stage-ui|stage-web): fix embed error, added…
Browse files Browse the repository at this point in the history
… vector tests, implemented provider config

Signed-off-by: Neko Ayaka <neko@ayaka.moe>
  • Loading branch information
nekomeowww committed Feb 27, 2025
1 parent 49be0b2 commit ab893fc
Show file tree
Hide file tree
Showing 17 changed files with 147 additions and 57 deletions.
3 changes: 3 additions & 0 deletions apps/stage-web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
"@xsai/stream-text": "catalog:",
"@xsai/utils-chat": "catalog:",
"defu": "^6.1.4",
"drizzle-kit": "^0.30.5",
"drizzle-orm": "^0.40.0",
"jszip": "^3.10.1",
"nprogress": "^0.2.0",
"ofetch": "^1.4.1",
Expand Down Expand Up @@ -92,6 +94,7 @@
"@iconify-json/svg-spinners": "^1.2.2",
"@iconify/utils": "^2.3.0",
"@intlify/unplugin-vue-i18n": "^6.0.3",
"@proj-airi/drizzle-duckdb-wasm": "workspace:^",
"@proj-airi/elevenlabs": "workspace:^",
"@proj-airi/lobe-icons": "workspace:^",
"@proj-airi/provider-transformers": "workspace:^",
Expand Down
14 changes: 8 additions & 6 deletions apps/stage-web/src/components/Widgets/ModelProviderSettings.vue
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
<script setup lang="ts">
import { Collapsable } from '@proj-airi/stage-ui/components'
import { useProvidersStore } from '@proj-airi/stage-ui/stores'
import { toJsonSchema } from '@valibot/to-json-schema'
import { storeToRefs } from 'pinia'
import { description, object, optional, pipe, record, string, title } from 'valibot'
import { computed, ref } from 'vue'
import { computed } from 'vue'
interface ModelProvider {
id: string
Expand Down Expand Up @@ -158,13 +160,13 @@ const providers = computed<ModelProvider[]>(() => [
},
])
const providerValues = ref<Record<string, Record<string, string>>>({})
const { providers: providerValues } = storeToRefs(useProvidersStore())
function getFieldValue(providerId: string, fieldName: string): string {
function getFieldValue(providerId: string, fieldName: string): unknown {
return providerValues.value[providerId]?.[fieldName] || ''
}
function setFieldValue(providerId: string, fieldName: string, value: string) {
function setFieldValue(providerId: string, fieldName: string, value: unknown) {
if (!providerValues.value[providerId]) {
providerValues.value[providerId] = {}
}
Expand All @@ -176,7 +178,7 @@ function getRecordEntries(providerId: string, fieldName: string): Array<[string,
if (!value)
return [['', '']]
try {
return Object.entries(JSON.parse(value))
return Object.entries(value)
}
catch {
return [['', '']]
Expand All @@ -191,7 +193,7 @@ function setRecordValue(providerId: string, fieldName: string, entries: Array<[s
}
const record = Object.fromEntries(validEntries)
setFieldValue(providerId, fieldName, JSON.stringify(record))
setFieldValue(providerId, fieldName, record)
}
function addRecordEntry(entries: Array<[string, string]>) {
Expand Down
4 changes: 2 additions & 2 deletions apps/stage-web/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ export default defineConfig({

resolve: {
alias: {
'@proj-airi/stage-ui': resolve(join(import.meta.dirname, '..', '..', 'packages', 'stage-ui', 'dist')),
'@proj-airi/stage-ui/stores': resolve(join(import.meta.dirname, '..', '..', 'packages', 'stage-ui', 'dist', 'stores')),
'@proj-airi/stage-ui': resolve(join(import.meta.dirname, '..', '..', 'packages', 'stage-ui', 'src')),
'@proj-airi/stage-ui/stores': resolve(join(import.meta.dirname, '..', '..', 'packages', 'stage-ui', 'src', 'stores')),
},
},

Expand Down
2 changes: 2 additions & 0 deletions cspell.config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ words:
- neuri
- Neuro
- Neuro-sama
- nomic
- novita
- nuxi
- nuxt
Expand Down Expand Up @@ -131,6 +132,7 @@ words:
- superjson
- tamagotchi
- taze
- tolist
- tresjs
- typeschema
- unhead
Expand Down
8 changes: 8 additions & 0 deletions packages/drizzle-duckdb-wasm/src/driver.browser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,12 @@ describe('drizzle with duckdb wasm in browser', { timeout: 10000 }, async () =>
expect(await db2.execute('SHOW TABLES')).toEqual([{ name: 'test' }])
expect(await db2.execute('SELECT * FROM test')).toEqual([{ v: 1 }, { v: 2 }, { v: 3 }])
})

it('should create a table with a float array column', async () => {
const db = drizzle({ connection: { bundles: getImportUrlBundles() } })
await db.execute('CREATE TABLE test (v FLOAT[26880], v2 text)')
await db.execute('INSERT INTO test VALUES (1, 2, 3, 4, "test")')
const res = await db.execute('SELECT * FROM test')
expect(res).toEqual([{ v: [1, 2, 3, 4], v2: 'test' }])
})
})
8 changes: 8 additions & 0 deletions packages/drizzle-duckdb-wasm/src/driver.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,12 @@ describe('drizzle with duckdb wasm in node', { timeout: 10000 }, async () => {
expect(await db2.execute('SHOW TABLES')).toEqual([{ name: 'test' }])
expect(await db2.execute('SELECT * FROM test')).toEqual([{ v: 1 }, { v: 2 }, { v: 3 }])
})

it('should create a table with a float array column', async () => {
const db = drizzle({ connection: { bundles: getBundles() } })
await db.execute('CREATE TABLE vector_test_table (v FLOAT[26880], v2 text)')
await db.execute(`INSERT INTO vector_test_table VALUES (${JSON.stringify(Array.from({ length: 26880 }).fill(1))}, 'text')`)
const res = await db.execute('SELECT * FROM vector_test_table')
expect(res).toEqual([{ v: Array.from({ length: 26880 }).fill(1), v2: 'text' }])
})
})
20 changes: 13 additions & 7 deletions packages/provider-transformers/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
import type { CreateProviderOptions, EmbedProviderWithExtraOptions } from '@xsai-ext/shared-providers'
import type { EmbedResponse } from '@xsai/embed'
import type { CommonRequestOptions } from '@xsai/shared'
import type { LoadOptions, WorkerMessageEvent } from './types'
import type { LoadOptionProgressCallback, LoadOptions, WorkerMessageEvent } from './types'

import { merge } from '@xsai-ext/shared-providers'
import defu from 'defu'

export type Loadable<P, T = string, T2 = undefined> = P & {
loadEmbed: (model: (string & {}) | T, options?: T2) => Promise<void>
terminateEmbed: () => void
}

export function createEmbedProvider<T extends string, T2 extends Omit<CommonRequestOptions, 'baseURL' | 'model'> & LoadOptions>(createOptions: CreateProviderOptions): Loadable<EmbedProviderWithExtraOptions<T, T2>, T, T2> {
function createEmbedProvider<T extends string, T2 extends Omit<CommonRequestOptions, 'baseURL' | 'model'> & LoadOptions>(createOptions: CreateProviderOptions): Loadable<EmbedProviderWithExtraOptions<T, T2>, T, T2> {
let worker: Worker
let isReady = false

function loadModel(model: (string & {}) | T, options: T2) {
function loadModel(model: (string & {}) | T, options?: T2) {
return new Promise<void>((resolve, reject) => {
const onProgress = options.onProgress
delete options.onProgress
let onProgress: LoadOptionProgressCallback | undefined
if (options != null && 'onProgress' in options && options.onProgress != null) {
onProgress = options?.onProgress
delete options?.onProgress
}

try {
const workerURL = new URL(createOptions.baseURL)
Expand Down Expand Up @@ -95,18 +99,20 @@ export function createEmbedProvider<T extends string, T2 extends Omit<CommonRequ
break
case 'extractResult':
resultDone = true

// eslint-disable-next-line no-case-declarations
const result = { data: [{ embedding: event.data.data.output.data, index: 0, object: 'embedding' }], model, object: 'list', usage: { prompt_tokens: 0, total_tokens: 0 } } satisfies EmbedResponse
// eslint-disable-next-line no-case-declarations
const encoder = new TextEncoder()

resolve(new Response(encoder.encode(JSON.stringify(result))))

break
}
})

if (!errored && !resultDone)
worker.postMessage({ type: 'extract', data: { text, options: body as any } } satisfies WorkerMessageEvent)
worker.postMessage({ type: 'extract', data: { text, options: defu<LoadOptions, LoadOptions[]>(options, { pooling: 'mean', normalize: true }) } } satisfies WorkerMessageEvent)
})
})
},
Expand All @@ -123,6 +129,6 @@ export function createEmbedProvider<T extends string, T2 extends Omit<CommonRequ

export function createTransformers(options: { embedWorkerURL: string }) {
return merge(
createEmbedProvider<'Xenova/all-MiniLM-L6-v2', Omit<CreateProviderOptions, 'baseURL'> & LoadOptions>({ baseURL: `xsai-provider-ext:///?worker-url=${options.embedWorkerURL}&other=` }),
createEmbedProvider<'Xenova/all-MiniLM-L6-v2', Omit<CreateProviderOptions, 'baseURL'> & LoadOptions>({ baseURL: `xsai-provider-ext:///?worker-url=${options.embedWorkerURL}&other=true` }),
)
}
2 changes: 1 addition & 1 deletion packages/provider-transformers/src/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ export enum MessageStatus {
Ready = 'ready',
}

export type LoadOptions = Omit<PretrainedOptions & ModelSpecificPretrainedOptions, 'progress_callback'> & { onProgress?: LoadOptionProgressCallback }
export type LoadOptions = Omit<PretrainedOptions & ModelSpecificPretrainedOptions, 'progress_callback'> & { onProgress?: LoadOptionProgressCallback } & FeatureExtractionPipelineOptions
export type LoadOptionProgressCallback = (progress: ProgressInfo) => void | Promise<void>
export type { ProgressInfo }

Expand Down
3 changes: 2 additions & 1 deletion packages/provider-transformers/src/worker/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ async function load(modelId: string, options?: Omit<PipelineOptionsFrom<typeof p

async function extract(text: string | string[], options?: FeatureExtractionPipelineOptions) {
const result = await embed(text, options)
self.postMessage({ type: 'extractResult', data: { input: { text, options }, output: { data: Array.from(result.data), dims: result.dims } } } satisfies WorkerMessageEvent)
const resultArray = result.tolist()
self.postMessage({ type: 'extractResult', data: { input: { text, options }, output: { data: Array.from(resultArray[0] || []), dims: result.dims } } } satisfies WorkerMessageEvent)
}

self.addEventListener('message', (event: MessageEvent<WorkerMessageEvent>) => {
Expand Down
5 changes: 4 additions & 1 deletion packages/stage-ui/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@
"typecheck": "vue-tsc --noEmit"
},
"devDependencies": {
"@electron-toolkit/preload": "^3.0.1"
"@electron-toolkit/preload": "^3.0.1",
"@proj-airi/provider-transformers": "workspace:^",
"@proj-airi/utils-transformers": "workspace:^",
"@xsai/embed": "catalog:"
}
}
32 changes: 25 additions & 7 deletions packages/stage-ui/src/components/Widgets/Stage.vue
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
<script setup lang="ts">
import type { ElectronAPI } from '@electron-toolkit/preload'
import type { DuckDBWasmDrizzleDatabase } from '@proj-airi/drizzle-duckdb-wasm'
import type { Emotion } from '../../constants/emotions'
import { drizzle } from '@proj-airi/drizzle-duckdb-wasm'
// import { createTransformers } from '@proj-airi/provider-transformers'
// import embedWorkerURL from '@proj-airi/provider-transformers/worker?worker&url'
// import { embed } from '@xsai/embed'
import { generateSpeech } from '@xsai/generate-speech'
import { createUnElevenLabs } from '@xsai/providers'
import { sql } from 'drizzle-orm'
import { storeToRefs } from 'pinia'
import { onMounted, onUnmounted, ref } from 'vue'
import { useI18n } from 'vue-i18n'
Expand All @@ -22,20 +28,18 @@ import VRMScene from '../Scenes/VRM.vue'
import '../../utils/live2d-zip-loader'
withDefaults(defineProps<{
paused?: boolean
}>(), {
paused: false,
})
withDefaults(defineProps<{ paused?: boolean }>(), { paused: false })
const vrmViewerRef = ref<{ setExpression: (expression: string) => void }>()
const db = ref<DuckDBWasmDrizzleDatabase>()
// const transformersProvider = createTransformers({ embedWorkerURL })
const vrmViewerRef = ref<{ setExpression: (expression: string) => void }>()
const motion = ref('')
const { stageView, elevenLabsApiKey, elevenlabsVoiceEnglish, elevenlabsVoiceJapanese } = storeToRefs(useSettings())
const { mouthOpenSize } = storeToRefs(useSpeakingStore())
const { audioContext, calculateVolume } = useAudioContext()
const { onBeforeMessageComposed, onBeforeSend, onTokenLiteral, onTokenSpecial, onStreamEnd, streamingMessage } = useChatStore()
const { onBeforeMessageComposed, onBeforeSend, onTokenLiteral, onTokenSpecial, onStreamEnd, streamingMessage, onAssistantResponseEnd } = useChatStore()
const { process } = useMarkdown()
const { locale } = useI18n()
Expand Down Expand Up @@ -184,6 +188,15 @@ onStreamEnd(async () => {
await delaysQueue.add(llmInferenceEndToken)
})
onAssistantResponseEnd(async (_message) => {
// const res = await embed({
// ...transformersProvider.embed('Xenova/nomic-embed-text-v1'),
// input: message,
// })
// await db.value?.execute(`INSERT INTO memory_test (vec) VALUES (${JSON.stringify(res.embedding)});`)
})
onUnmounted(() => {
lipSyncStarted.value = false
Expand All @@ -207,6 +220,11 @@ onMounted(() => {
motion.value = EmotionThinkMotionName
})
})
onMounted(async () => {
db.value = drizzle('duckdb-wasm://?bundles=import-url')
await db.value.execute(sql`CREATE TABLE memory_test (vec FLOAT[768]);`)
})
</script>

<template>
Expand Down
24 changes: 17 additions & 7 deletions packages/stage-ui/src/stores/chat.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
import type { AssistantMessage, Message } from '@xsai/shared-chat'

import { defineStore, storeToRefs } from 'pinia'
import { ref } from 'vue'
import { ref, toRaw } from 'vue'
import { useI18n } from 'vue-i18n'

import { useLlmmarkerParser } from '../composables/llmmarkerParser'
import SystemPromptV2 from '../constants/prompts/system-v2'
import { useLLM } from '../stores/llm'
import { useSettings } from '../stores/settings'
import { useProvidersStore } from '../stores/providers'
import { asyncIteratorFromReadableStream } from '../utils/iterator'

export const useChatStore = defineStore('chat', () => {
const { stream } = useLLM()
const { t } = useI18n()
const { openAiApiBaseURL, openAiApiKey, openAiModel } = storeToRefs(useSettings())
const { providers: providerValues } = storeToRefs(useProvidersStore())

const onBeforeMessageComposedHooks = ref<Array<(message: string) => Promise<void>>>([])
const onAfterMessageComposedHooks = ref<Array<(message: string) => Promise<void>>>([])
Expand All @@ -22,6 +22,7 @@ export const useChatStore = defineStore('chat', () => {
const onTokenLiteralHooks = ref<Array<(literal: string) => Promise<void>>>([])
const onTokenSpecialHooks = ref<Array<(special: string) => Promise<void>>>([])
const onStreamEndHooks = ref<Array<() => Promise<void>>>([])
const onAssistantResponseEndHooks = ref<Array<(message: string) => Promise<void>>>([])

function onBeforeMessageComposed(cb: (message: string) => Promise<void>) {
onBeforeMessageComposedHooks.value.push(cb)
Expand Down Expand Up @@ -51,6 +52,10 @@ export const useChatStore = defineStore('chat', () => {
onStreamEndHooks.value.push(cb)
}

function onAssistantResponseEnd(cb: (message: string) => Promise<void>) {
onAssistantResponseEndHooks.value.push(cb)
}

const messages = ref<Array<Message>>([
SystemPromptV2(
t('prompt.prefix'),
Expand All @@ -73,15 +78,15 @@ export const useChatStore = defineStore('chat', () => {
}

const {
baseUrl = openAiApiBaseURL.value,
apiKey = openAiApiKey.value,
model = openAiModel.value,
baseUrl = providerValues.value['openrouter-ai']?.baseUrl as string | undefined || '',
apiKey = providerValues.value['openrouter-ai']?.apiKey as string | undefined || '',
model = providerValues.value['openrouter-ai']?.model as { id: string } | undefined || { id: 'openai/gpt-4o-mini' },
} = options ?? { }

streamingMessage.value = { role: 'assistant', content: '' }
messages.value.push({ role: 'user', content: sendingMessage })
messages.value.push(streamingMessage.value)
const newMessages = messages.value.slice(0, messages.value.length - 1)
const newMessages = messages.value.slice(0, messages.value.length - 1).map(msg => toRaw(msg))

for (const hook of onAfterMessageComposedHooks.value) {
await hook(sendingMessage)
Expand Down Expand Up @@ -125,6 +130,10 @@ export const useChatStore = defineStore('chat', () => {
await hook()
}

for (const hook of onAssistantResponseEndHooks.value) {
await hook(fullText)
}

// eslint-disable-next-line no-console
console.debug('LLM output:', fullText)
}
Expand All @@ -140,5 +149,6 @@ export const useChatStore = defineStore('chat', () => {
onTokenLiteral,
onTokenSpecial,
onStreamEnd,
onAssistantResponseEnd,
}
})
1 change: 1 addition & 0 deletions packages/stage-ui/src/stores/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export * from './audio'
export * from './chat'
export * from './llm'
export * from './providers'
export * from './settings'
10 changes: 10 additions & 0 deletions packages/stage-ui/src/stores/providers.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { useLocalStorage } from '@vueuse/core'
import { defineStore } from 'pinia'

export const useProvidersStore = defineStore('providers', () => {
const providers = useLocalStorage<Record<string, Record<string, unknown>>>('settings/credentials/providers', {})

return {
providers,
}
})
3 changes: 2 additions & 1 deletion packages/stage-ui/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"moduleResolution": "Bundler",
"resolveJsonModule": true,
"types": [
"vitest"
"vitest",
"vite/client"
],
"allowJs": true,
"strict": true,
Expand Down
Loading

0 comments on commit ab893fc

Please sign in to comment.