From 95d7ab5391b0de66ba1e5d3298da4672eff67e32 Mon Sep 17 00:00:00 2001 From: Ian Philips Date: Wed, 26 Feb 2025 17:16:21 -0800 Subject: [PATCH] Infer unit/metric from question title --- backend/api/package.json | 1 + backend/api/src/infer-numeric-unit.ts | 50 +++++++++++ backend/api/src/routes.ts | 2 + backend/shared/package.json | 1 + backend/shared/src/helpers/gemini.ts | 82 +++++++++++++++++++ common/src/api/schema.ts | 14 ++++ common/src/secrets.ts | 1 + .../new-contract/contract-params-form.tsx | 20 ++++- .../multi-numeric-range-section.tsx | 6 +- yarn.lock | 5 ++ 10 files changed, 178 insertions(+), 4 deletions(-) create mode 100644 backend/api/src/infer-numeric-unit.ts create mode 100644 backend/shared/src/helpers/gemini.ts diff --git a/backend/api/package.json b/backend/api/package.json index 452b24d9af..6eb1fb5fa5 100644 --- a/backend/api/package.json +++ b/backend/api/package.json @@ -29,6 +29,7 @@ "@giphy/js-fetch-api": "5.0.0", "@google-cloud/monitoring": "4.0.0", "@google-cloud/secret-manager": "4.2.1", + "@google/generative-ai": "0.22.0", "@mendable/firecrawl-js": "1.8.5", "@supabase/supabase-js": "2.38.5", "@tiptap/core": "2.0.0-beta.204", diff --git a/backend/api/src/infer-numeric-unit.ts b/backend/api/src/infer-numeric-unit.ts new file mode 100644 index 0000000000..9f1549c274 --- /dev/null +++ b/backend/api/src/infer-numeric-unit.ts @@ -0,0 +1,50 @@ +import { APIError, APIHandler } from './helpers/endpoint' +import { track } from 'shared/analytics' +import { rateLimitByUser } from './helpers/rate-limit' +import { HOUR_MS } from 'common/util/time' +import { promptGemini, parseGeminiResponseAsJson } from 'shared/helpers/gemini' +import { log } from 'shared/utils' + +export const inferNumericUnit: APIHandler<'infer-numeric-unit'> = + rateLimitByUser( + async (props, auth) => { + const { question, description } = props + + try { + const systemPrompt = ` + You are an AI assistant that extracts the most appropriate unit of measurement from prediction market questions. + You will return ONLY a JSON object with a single "unit" field containing the inferred unit as a string. + For example: {"unit": "people"} + + Guidelines: + - If no specific unit is mentioned, infer the most logical unit based on the context + - Common units include: people, dollars, percent, points, votes, etc. + - If the question is about a count of items, use the plural form (e.g., "people" not "person") + - If no unit can be reasonably inferred, return an empty json object + ` + + const prompt = ` + Question: ${question} + ${ + description && description !== '

' + ? `Description: ${description}` + : '' + } + ` + const response = await promptGemini(prompt, { system: systemPrompt }) + const result = parseGeminiResponseAsJson(response) + log.info('Inferred unit:', { result }) + + track(auth.uid, 'infer-numeric-unit', { + question, + inferred_unit: result.unit, + }) + + return { unit: result.unit } + } catch (error) { + log.error('Error inferring unit:', { error }) + throw new APIError(500, 'Failed to infer unit from question') + } + }, + { maxCalls: 60, windowMs: HOUR_MS } + ) diff --git a/backend/api/src/routes.ts b/backend/api/src/routes.ts index 8b084e499b..781744f393 100644 --- a/backend/api/src/routes.ts +++ b/backend/api/src/routes.ts @@ -171,6 +171,7 @@ import { generateAIDateRanges, regenerateDateMidpoints, } from './generate-ai-date-ranges' +import { inferNumericUnit } from './infer-numeric-unit' // we define the handlers in this object in order to typecheck that every API has a handler export const handlers: { [k in APIPath]: APIHandler } = { 'refresh-all-clients': refreshAllClients, @@ -352,6 +353,7 @@ export const handlers: { [k in APIPath]: APIHandler } = { 'purchase-contract-boost': purchaseContractBoost, 'generate-ai-numeric-ranges': generateAINumericRanges, 'regenerate-numeric-midpoints': regenerateNumericMidpoints, + 'infer-numeric-unit': inferNumericUnit, 'generate-ai-date-ranges': generateAIDateRanges, 'regenerate-date-midpoints': regenerateDateMidpoints, } diff --git a/backend/shared/package.json b/backend/shared/package.json index 044e6b9346..f2082471f1 100644 --- a/backend/shared/package.json +++ b/backend/shared/package.json @@ -14,6 +14,7 @@ "@anthropic-ai/sdk": "0.24.3", "@google-cloud/monitoring": "4.0.0", "@google-cloud/secret-manager": "4.2.1", + "@google/generative-ai": "0.22.0", "@stdlib/math-base-special-betaincinv": "0.2.1", "@tiptap/core": "2.0.0-beta.204", "@tiptap/html": "2.0.0-beta.204", diff --git a/backend/shared/src/helpers/gemini.ts b/backend/shared/src/helpers/gemini.ts new file mode 100644 index 0000000000..f9a23cc0e2 --- /dev/null +++ b/backend/shared/src/helpers/gemini.ts @@ -0,0 +1,82 @@ +import { GoogleGenerativeAI } from '@google/generative-ai' +import { log } from 'shared/utils' +import { APIError } from 'common/api/utils' + +export const models = { + flash: 'gemini-2.0-flash' as const, +} + +export type model_types = (typeof models)[keyof typeof models] + +export const promptGemini = async ( + prompt: string, + options: { system?: string; model?: model_types } = {} +) => { + const { model = models.flash, system } = options + + const apiKey = process.env.GEMINI_API_KEY + + if (!apiKey) { + throw new APIError(500, 'Missing GEMINI_API_KEY') + } + + const genAI = new GoogleGenerativeAI(apiKey) + const geminiModel = genAI.getGenerativeModel({ model }) + + try { + // Combine system prompt and user prompt if system is provided + const fullPrompt = system ? `${system}\n\n${prompt}` : prompt + + const result = await geminiModel.generateContent(fullPrompt) + const response = result.response.text() + + log('Gemini returned message:', response) + return response + } catch (error: any) { + log.error(`Error with Gemini API: ${error.message}`) + throw new APIError(500, 'Failed to get response from Gemini') + } +} + +// Helper function to clean Gemini responses from markdown formatting +const removeJsonTicksFromGeminiResponse = (response: string): string => { + // Remove markdown code block formatting if present + const jsonBlockRegex = /```(?:json)?\s*([\s\S]*?)```/ + const match = response.match(jsonBlockRegex) + + if (match && match[1]) { + return match[1].trim() + } + + // If no markdown formatting found, return the original response + return response.trim() +} + +// Helper function to ensure the response is valid JSON +export const parseGeminiResponseAsJson = (response: string): any => { + const cleanedResponse = removeJsonTicksFromGeminiResponse(response) + + try { + // Try to parse as is + return JSON.parse(cleanedResponse) + } catch (error) { + // If parsing fails, try to handle common issues + + // Check if it's an array wrapped in extra text + const arrayStart = cleanedResponse.indexOf('[') + const arrayEnd = cleanedResponse.lastIndexOf(']') + + if (arrayStart !== -1 && arrayEnd !== -1 && arrayEnd > arrayStart) { + const potentialArray = cleanedResponse.substring(arrayStart, arrayEnd + 1) + try { + return JSON.parse(potentialArray) + } catch (e) { + // If still fails, throw the original error + throw error + } + } + + // If we can't fix it, throw the original error + throw error + } +} diff --git a/common/src/api/schema.ts b/common/src/api/schema.ts index dd5502deab..f32e505d80 100644 --- a/common/src/api/schema.ts +++ b/common/src/api/schema.ts @@ -2191,6 +2191,20 @@ export const API = (_apiTypeCheck = { }) .strict(), }, + 'infer-numeric-unit': { + method: 'POST', + visibility: 'public', + authed: true, + returns: {} as { + unit: string + }, + props: z + .object({ + question: z.string(), + description: z.string().optional(), + }) + .strict(), + }, 'generate-ai-date-ranges': { method: 'POST', visibility: 'public', diff --git a/common/src/secrets.ts b/common/src/secrets.ts index 465c5b9935..0a8989e2ff 100644 --- a/common/src/secrets.ts +++ b/common/src/secrets.ts @@ -38,6 +38,7 @@ export const secrets = ( 'FIRECRAWL_API_KEY', 'SPORTSDB_KEY', 'VERIFIED_PHONE_NUMBER', + 'GEMINI_API_KEY', // Some typescript voodoo to keep the string literal types while being not readonly. ] as const ).concat() diff --git a/web/components/new-contract/contract-params-form.tsx b/web/components/new-contract/contract-params-form.tsx index 4d23235b43..52ccfd32e3 100644 --- a/web/components/new-contract/contract-params-form.tsx +++ b/web/components/new-contract/contract-params-form.tsx @@ -595,6 +595,21 @@ export function ContractParamsForm(props: { } } + const inferUnit = async () => { + if (!question || unit !== '') return + try { + const result = await api('infer-numeric-unit', { + question, + description: editor?.getHTML(), + }) + if (result.unit) { + setUnit(result.unit) + } + } catch (e) { + console.error('Error inferring unit:', e) + } + } + return ( @@ -608,7 +623,10 @@ export function ContractParamsForm(props: { maxLength={MAX_QUESTION_LENGTH} value={question} onChange={(e) => setQuestion(e.target.value || '')} - onBlur={(e) => findTopicsAndSimilarQuestions(e.target.value || '')} + onBlur={(e) => { + if (outcomeType === 'MULTI_NUMERIC') inferUnit() + findTopicsAndSimilarQuestions(e.target.value || '') + }} /> {similarContracts.length ? ( diff --git a/web/components/new-contract/multi-numeric-range-section.tsx b/web/components/new-contract/multi-numeric-range-section.tsx index abc21a35ba..9ff8a61e30 100644 --- a/web/components/new-contract/multi-numeric-range-section.tsx +++ b/web/components/new-contract/multi-numeric-range-section.tsx @@ -224,7 +224,7 @@ export const MultiNumericRangeSection = (props: { useEffect(() => { if (isTimeUnit(unit)) { setError( - 'Time units are not supported for numeric ranges. Date ranges are coming soon!' + 'Time metrics are not supported for numeric ranges. Date ranges are coming soon!' ) } else { setError('') @@ -320,7 +320,7 @@ export const MultiNumericRangeSection = (props: { - Range & unit + Range & metric {minMaxError && ( @@ -355,7 +355,7 @@ export const MultiNumericRangeSection = (props: { e.stopPropagation()} onChange={(e) => setUnit(e.target.value)} onBlur={handleRangeBlur} diff --git a/yarn.lock b/yarn.lock index da069b0069..ac4411f9b7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2269,6 +2269,11 @@ teeny-request "^8.0.0" uuid "^8.0.0" +"@google/generative-ai@0.22.0": + version "0.22.0" + resolved "https://registry.yarnpkg.com/@google/generative-ai/-/generative-ai-0.22.0.tgz#e77a1a3911f4f98bf9e965ad7c4b1ee2130abd26" + integrity sha512-mLR3PDWCk5O/BWNyDvFDIiwKeXQmFGZ+kJFd9m73QrUPCFREttJyVbBPTW4y9CwTbaltLMDaLDfroCrRv5Bl8Q== + "@grpc/grpc-js@~1.10.0": version "1.10.6" resolved "https://registry.yarnpkg.com/@grpc/grpc-js/-/grpc-js-1.10.6.tgz#1e3eb1af911dc888fbef7452f56a7573b8284d54"