diff --git a/backend/api/package.json b/backend/api/package.json
index 452b24d9af..6eb1fb5fa5 100644
--- a/backend/api/package.json
+++ b/backend/api/package.json
@@ -29,6 +29,7 @@
"@giphy/js-fetch-api": "5.0.0",
"@google-cloud/monitoring": "4.0.0",
"@google-cloud/secret-manager": "4.2.1",
+ "@google/generative-ai": "0.22.0",
"@mendable/firecrawl-js": "1.8.5",
"@supabase/supabase-js": "2.38.5",
"@tiptap/core": "2.0.0-beta.204",
diff --git a/backend/api/src/infer-numeric-unit.ts b/backend/api/src/infer-numeric-unit.ts
new file mode 100644
index 0000000000..9f1549c274
--- /dev/null
+++ b/backend/api/src/infer-numeric-unit.ts
@@ -0,0 +1,50 @@
+import { APIError, APIHandler } from './helpers/endpoint'
+import { track } from 'shared/analytics'
+import { rateLimitByUser } from './helpers/rate-limit'
+import { HOUR_MS } from 'common/util/time'
+import { promptGemini, parseGeminiResponseAsJson } from 'shared/helpers/gemini'
+import { log } from 'shared/utils'
+
+export const inferNumericUnit: APIHandler<'infer-numeric-unit'> =
+ rateLimitByUser(
+ async (props, auth) => {
+ const { question, description } = props
+
+ try {
+ const systemPrompt = `
+ You are an AI assistant that extracts the most appropriate unit of measurement from prediction market questions.
+ You will return ONLY a JSON object with a single "unit" field containing the inferred unit as a string.
+ For example: {"unit": "people"}
+
+ Guidelines:
+ - If no specific unit is mentioned, infer the most logical unit based on the context
+ - Common units include: people, dollars, percent, points, votes, etc.
+ - If the question is about a count of items, use the plural form (e.g., "people" not "person")
+ - If no unit can be reasonably inferred, return an empty json object
+ `
+
+ const prompt = `
+ Question: ${question}
+ ${
+ description && description !== '
'
+ ? `Description: ${description}`
+ : ''
+ }
+ `
+ const response = await promptGemini(prompt, { system: systemPrompt })
+ const result = parseGeminiResponseAsJson(response)
+ log.info('Inferred unit:', { result })
+
+ track(auth.uid, 'infer-numeric-unit', {
+ question,
+ inferred_unit: result.unit,
+ })
+
+ return { unit: result.unit }
+ } catch (error) {
+ log.error('Error inferring unit:', { error })
+ throw new APIError(500, 'Failed to infer unit from question')
+ }
+ },
+ { maxCalls: 60, windowMs: HOUR_MS }
+ )
diff --git a/backend/api/src/routes.ts b/backend/api/src/routes.ts
index 8b084e499b..781744f393 100644
--- a/backend/api/src/routes.ts
+++ b/backend/api/src/routes.ts
@@ -171,6 +171,7 @@ import {
generateAIDateRanges,
regenerateDateMidpoints,
} from './generate-ai-date-ranges'
+import { inferNumericUnit } from './infer-numeric-unit'
// we define the handlers in this object in order to typecheck that every API has a handler
export const handlers: { [k in APIPath]: APIHandler } = {
'refresh-all-clients': refreshAllClients,
@@ -352,6 +353,7 @@ export const handlers: { [k in APIPath]: APIHandler } = {
'purchase-contract-boost': purchaseContractBoost,
'generate-ai-numeric-ranges': generateAINumericRanges,
'regenerate-numeric-midpoints': regenerateNumericMidpoints,
+ 'infer-numeric-unit': inferNumericUnit,
'generate-ai-date-ranges': generateAIDateRanges,
'regenerate-date-midpoints': regenerateDateMidpoints,
}
diff --git a/backend/shared/package.json b/backend/shared/package.json
index 044e6b9346..f2082471f1 100644
--- a/backend/shared/package.json
+++ b/backend/shared/package.json
@@ -14,6 +14,7 @@
"@anthropic-ai/sdk": "0.24.3",
"@google-cloud/monitoring": "4.0.0",
"@google-cloud/secret-manager": "4.2.1",
+ "@google/generative-ai": "0.22.0",
"@stdlib/math-base-special-betaincinv": "0.2.1",
"@tiptap/core": "2.0.0-beta.204",
"@tiptap/html": "2.0.0-beta.204",
diff --git a/backend/shared/src/helpers/gemini.ts b/backend/shared/src/helpers/gemini.ts
new file mode 100644
index 0000000000..f9a23cc0e2
--- /dev/null
+++ b/backend/shared/src/helpers/gemini.ts
@@ -0,0 +1,82 @@
+import { GoogleGenerativeAI } from '@google/generative-ai'
+import { log } from 'shared/utils'
+import { APIError } from 'common/api/utils'
+
+export const models = {
+ flash: 'gemini-2.0-flash' as const,
+}
+
+export type model_types = (typeof models)[keyof typeof models]
+
+export const promptGemini = async (
+ prompt: string,
+ options: { system?: string; model?: model_types } = {}
+) => {
+ const { model = models.flash, system } = options
+
+ const apiKey = process.env.GEMINI_API_KEY
+
+ if (!apiKey) {
+ throw new APIError(500, 'Missing GEMINI_API_KEY')
+ }
+
+ const genAI = new GoogleGenerativeAI(apiKey)
+ const geminiModel = genAI.getGenerativeModel({ model })
+
+ try {
+ // Combine system prompt and user prompt if system is provided
+ const fullPrompt = system ? `${system}\n\n${prompt}` : prompt
+
+ const result = await geminiModel.generateContent(fullPrompt)
+ const response = result.response.text()
+
+ log('Gemini returned message:', response)
+ return response
+ } catch (error: any) {
+ log.error(`Error with Gemini API: ${error.message}`)
+ throw new APIError(500, 'Failed to get response from Gemini')
+ }
+}
+
+// Helper function to clean Gemini responses from markdown formatting
+const removeJsonTicksFromGeminiResponse = (response: string): string => {
+ // Remove markdown code block formatting if present
+ const jsonBlockRegex = /```(?:json)?\s*([\s\S]*?)```/
+ const match = response.match(jsonBlockRegex)
+
+ if (match && match[1]) {
+ return match[1].trim()
+ }
+
+ // If no markdown formatting found, return the original response
+ return response.trim()
+}
+
+// Helper function to ensure the response is valid JSON
+export const parseGeminiResponseAsJson = (response: string): any => {
+ const cleanedResponse = removeJsonTicksFromGeminiResponse(response)
+
+ try {
+ // Try to parse as is
+ return JSON.parse(cleanedResponse)
+ } catch (error) {
+ // If parsing fails, try to handle common issues
+
+ // Check if it's an array wrapped in extra text
+ const arrayStart = cleanedResponse.indexOf('[')
+ const arrayEnd = cleanedResponse.lastIndexOf(']')
+
+ if (arrayStart !== -1 && arrayEnd !== -1 && arrayEnd > arrayStart) {
+ const potentialArray = cleanedResponse.substring(arrayStart, arrayEnd + 1)
+ try {
+ return JSON.parse(potentialArray)
+ } catch (e) {
+ // If still fails, throw the original error
+ throw error
+ }
+ }
+
+ // If we can't fix it, throw the original error
+ throw error
+ }
+}
diff --git a/common/src/api/schema.ts b/common/src/api/schema.ts
index dd5502deab..f32e505d80 100644
--- a/common/src/api/schema.ts
+++ b/common/src/api/schema.ts
@@ -2191,6 +2191,20 @@ export const API = (_apiTypeCheck = {
})
.strict(),
},
+ 'infer-numeric-unit': {
+ method: 'POST',
+ visibility: 'public',
+ authed: true,
+ returns: {} as {
+ unit: string
+ },
+ props: z
+ .object({
+ question: z.string(),
+ description: z.string().optional(),
+ })
+ .strict(),
+ },
'generate-ai-date-ranges': {
method: 'POST',
visibility: 'public',
diff --git a/common/src/secrets.ts b/common/src/secrets.ts
index 465c5b9935..0a8989e2ff 100644
--- a/common/src/secrets.ts
+++ b/common/src/secrets.ts
@@ -38,6 +38,7 @@ export const secrets = (
'FIRECRAWL_API_KEY',
'SPORTSDB_KEY',
'VERIFIED_PHONE_NUMBER',
+ 'GEMINI_API_KEY',
// Some typescript voodoo to keep the string literal types while being not readonly.
] as const
).concat()
diff --git a/web/components/new-contract/contract-params-form.tsx b/web/components/new-contract/contract-params-form.tsx
index 4d23235b43..52ccfd32e3 100644
--- a/web/components/new-contract/contract-params-form.tsx
+++ b/web/components/new-contract/contract-params-form.tsx
@@ -595,6 +595,21 @@ export function ContractParamsForm(props: {
}
}
+ const inferUnit = async () => {
+ if (!question || unit !== '') return
+ try {
+ const result = await api('infer-numeric-unit', {
+ question,
+ description: editor?.getHTML(),
+ })
+ if (result.unit) {
+ setUnit(result.unit)
+ }
+ } catch (e) {
+ console.error('Error inferring unit:', e)
+ }
+ }
+
return (
@@ -608,7 +623,10 @@ export function ContractParamsForm(props: {
maxLength={MAX_QUESTION_LENGTH}
value={question}
onChange={(e) => setQuestion(e.target.value || '')}
- onBlur={(e) => findTopicsAndSimilarQuestions(e.target.value || '')}
+ onBlur={(e) => {
+ if (outcomeType === 'MULTI_NUMERIC') inferUnit()
+ findTopicsAndSimilarQuestions(e.target.value || '')
+ }}
/>
{similarContracts.length ? (
diff --git a/web/components/new-contract/multi-numeric-range-section.tsx b/web/components/new-contract/multi-numeric-range-section.tsx
index abc21a35ba..9ff8a61e30 100644
--- a/web/components/new-contract/multi-numeric-range-section.tsx
+++ b/web/components/new-contract/multi-numeric-range-section.tsx
@@ -224,7 +224,7 @@ export const MultiNumericRangeSection = (props: {
useEffect(() => {
if (isTimeUnit(unit)) {
setError(
- 'Time units are not supported for numeric ranges. Date ranges are coming soon!'
+ 'Time metrics are not supported for numeric ranges. Date ranges are coming soon!'
)
} else {
setError('')
@@ -320,7 +320,7 @@ export const MultiNumericRangeSection = (props: {