Skip to content

Commit

Permalink
refactor: apply clem comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthieu-OD committed Nov 22, 2024
1 parent b03d76d commit d2b3b53
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 62 deletions.
15 changes: 8 additions & 7 deletions src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import { v4 as uuidv4 } from 'uuid';

import { LiteralClient } from '.';
import { sharedCache } from './cache/sharedcache';
import { getPromptCacheKey, putPrompt } from './cache/utils';
import { getPromptCacheKey } from './cache/utils';
import {
Dataset,
DatasetExperiment,
Expand Down Expand Up @@ -2152,12 +2152,9 @@ export class API {
*/
private async getPromptWithQuery(
query: string,
variables: Record<string, any>
variables: { id?: string; name?: string; version?: number }
) {
const { id, name, version } = variables;
const cachedPrompt = sharedCache.get(
getPromptCacheKey({ id, name, version })
);
const cachedPrompt = sharedCache.get(getPromptCacheKey(variables));
const timeout = cachedPrompt ? 1000 : undefined;

try {
Expand All @@ -2176,7 +2173,11 @@ export class API {
}

const prompt = new Prompt(this, promptData);
putPrompt(prompt);

sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

return prompt;
} catch (error) {
return cachedPrompt;
Expand Down
9 changes: 0 additions & 9 deletions src/cache/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import { Prompt } from '../prompt-engineering/prompt';
import { sharedCache } from './sharedcache';

export function getPromptCacheKey({
id,
name,
Expand All @@ -19,9 +16,3 @@ export function getPromptCacheKey({
}
throw new Error('Either id or name must be provided');
}

export function putPrompt(prompt: Prompt): void {
sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);
}
11 changes: 8 additions & 3 deletions tests/api.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { v4 as uuidv4 } from 'uuid';

import { ChatGeneration, IGenerationMessage, LiteralClient } from '../src';
import { sharedCache } from '../src/cache/sharedcache';
import { putPrompt } from '../src/cache/utils';
import { Dataset } from '../src/evaluation/dataset';
import { Score } from '../src/evaluation/score';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';
Expand Down Expand Up @@ -687,7 +686,9 @@ is a templated list.`;

it('should fallback to cache when getPromptById DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
putPrompt(prompt);
sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

jest
.spyOn(client.api as any, 'makeGqlCall')
Expand All @@ -699,7 +700,11 @@ is a templated list.`;

it('should fallback to cache when getPrompt DB call fails', async () => {
const prompt = new Prompt(client.api, mockPromptData);
putPrompt(prompt);

sharedCache.put(prompt.id, prompt);
sharedCache.put(prompt.name, prompt);
sharedCache.put(`${prompt.name}:${prompt.version}`, prompt);

jest.spyOn(axios, 'post').mockRejectedValueOnce(new Error('DB Error'));

const result = await client.api.getPrompt(prompt.id);
Expand Down
44 changes: 1 addition & 43 deletions tests/cache.test.ts
Original file line number Diff line number Diff line change
@@ -1,41 +1,9 @@
import { API } from '../src/api';
import { sharedCache } from '../src/cache/sharedcache';
import { getPromptCacheKey, putPrompt } from '../src/cache/utils';
import { Prompt, PromptConstructor } from '../src/prompt-engineering/prompt';
import { getPromptCacheKey } from '../src/cache/utils';

describe('Cache', () => {
let api: API;
let mockPrompt: Prompt;

beforeAll(() => {
api = {} as API;
});

beforeEach(() => {
sharedCache.clear();

const mockPromptData: PromptConstructor = {
id: 'test-id',
type: 'CHAT',
createdAt: '2023-01-01T00:00:00Z',
name: 'test-name',
version: 1,
metadata: {},
items: [],
templateMessages: [{ role: 'user', content: 'Hello', uuid: '123' }],
provider: 'test-provider',
settings: {
provider: 'test-provider',
model: 'test-model',
frequency_penalty: 0,
max_tokens: 100,
presence_penalty: 0,
temperature: 0.7,
top_p: 1
},
variables: []
};
mockPrompt = new Prompt(api, mockPromptData);
});

describe('Cache Utils', () => {
Expand Down Expand Up @@ -68,16 +36,6 @@ describe('Cache', () => {
);
});
});

describe('putPrompt', () => {
it('should store prompt with multiple keys', () => {
putPrompt(mockPrompt);

expect(sharedCache.get('test-id')).toEqual(mockPrompt);
expect(sharedCache.get('test-name')).toEqual(mockPrompt);
expect(sharedCache.get('test-name:1')).toEqual(mockPrompt);
});
});
});

describe('SharedCache', () => {
Expand Down

0 comments on commit d2b3b53

Please sign in to comment.