-
Notifications
You must be signed in to change notification settings - Fork 167
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Infer unit/metric from question title
- Loading branch information
1 parent
99e41de
commit 95d7ab5
Showing
10 changed files
with
178 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import { APIError, APIHandler } from './helpers/endpoint' | ||
import { track } from 'shared/analytics' | ||
import { rateLimitByUser } from './helpers/rate-limit' | ||
import { HOUR_MS } from 'common/util/time' | ||
import { promptGemini, parseGeminiResponseAsJson } from 'shared/helpers/gemini' | ||
import { log } from 'shared/utils' | ||
|
||
export const inferNumericUnit: APIHandler<'infer-numeric-unit'> = | ||
rateLimitByUser( | ||
async (props, auth) => { | ||
const { question, description } = props | ||
|
||
try { | ||
const systemPrompt = ` | ||
You are an AI assistant that extracts the most appropriate unit of measurement from prediction market questions. | ||
You will return ONLY a JSON object with a single "unit" field containing the inferred unit as a string. | ||
For example: {"unit": "people"} | ||
Guidelines: | ||
- If no specific unit is mentioned, infer the most logical unit based on the context | ||
- Common units include: people, dollars, percent, points, votes, etc. | ||
- If the question is about a count of items, use the plural form (e.g., "people" not "person") | ||
- If no unit can be reasonably inferred, return an empty json object | ||
` | ||
|
||
const prompt = ` | ||
Question: ${question} | ||
${ | ||
description && description !== '<p></p>' | ||
? `Description: ${description}` | ||
: '' | ||
} | ||
` | ||
const response = await promptGemini(prompt, { system: systemPrompt }) | ||
const result = parseGeminiResponseAsJson(response) | ||
log.info('Inferred unit:', { result }) | ||
|
||
track(auth.uid, 'infer-numeric-unit', { | ||
question, | ||
inferred_unit: result.unit, | ||
}) | ||
|
||
return { unit: result.unit } | ||
} catch (error) { | ||
log.error('Error inferring unit:', { error }) | ||
throw new APIError(500, 'Failed to infer unit from question') | ||
} | ||
}, | ||
{ maxCalls: 60, windowMs: HOUR_MS } | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import { GoogleGenerativeAI } from '@google/generative-ai' | ||
import { log } from 'shared/utils' | ||
import { APIError } from 'common/api/utils' | ||
|
||
export const models = { | ||
flash: 'gemini-2.0-flash' as const, | ||
} | ||
|
||
export type model_types = (typeof models)[keyof typeof models] | ||
|
||
export const promptGemini = async ( | ||
prompt: string, | ||
options: { system?: string; model?: model_types } = {} | ||
) => { | ||
const { model = models.flash, system } = options | ||
|
||
const apiKey = process.env.GEMINI_API_KEY | ||
|
||
if (!apiKey) { | ||
throw new APIError(500, 'Missing GEMINI_API_KEY') | ||
} | ||
|
||
const genAI = new GoogleGenerativeAI(apiKey) | ||
const geminiModel = genAI.getGenerativeModel({ model }) | ||
|
||
try { | ||
// Combine system prompt and user prompt if system is provided | ||
const fullPrompt = system ? `${system}\n\n${prompt}` : prompt | ||
|
||
const result = await geminiModel.generateContent(fullPrompt) | ||
const response = result.response.text() | ||
|
||
log('Gemini returned message:', response) | ||
return response | ||
} catch (error: any) { | ||
log.error(`Error with Gemini API: ${error.message}`) | ||
throw new APIError(500, 'Failed to get response from Gemini') | ||
} | ||
} | ||
|
||
// Helper function to clean Gemini responses from markdown formatting | ||
const removeJsonTicksFromGeminiResponse = (response: string): string => { | ||
// Remove markdown code block formatting if present | ||
const jsonBlockRegex = /```(?:json)?\s*([\s\S]*?)```/ | ||
const match = response.match(jsonBlockRegex) | ||
|
||
if (match && match[1]) { | ||
return match[1].trim() | ||
} | ||
|
||
// If no markdown formatting found, return the original response | ||
return response.trim() | ||
} | ||
|
||
// Helper function to ensure the response is valid JSON | ||
export const parseGeminiResponseAsJson = (response: string): any => { | ||
const cleanedResponse = removeJsonTicksFromGeminiResponse(response) | ||
|
||
try { | ||
// Try to parse as is | ||
return JSON.parse(cleanedResponse) | ||
} catch (error) { | ||
// If parsing fails, try to handle common issues | ||
|
||
// Check if it's an array wrapped in extra text | ||
const arrayStart = cleanedResponse.indexOf('[') | ||
const arrayEnd = cleanedResponse.lastIndexOf(']') | ||
|
||
if (arrayStart !== -1 && arrayEnd !== -1 && arrayEnd > arrayStart) { | ||
const potentialArray = cleanedResponse.substring(arrayStart, arrayEnd + 1) | ||
try { | ||
return JSON.parse(potentialArray) | ||
} catch (e) { | ||
// If still fails, throw the original error | ||
throw error | ||
} | ||
} | ||
|
||
// If we can't fix it, throw the original error | ||
throw error | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters