Feat(prompt) system prompt (#40)
Add a basic system prompt to the llm-server.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
  - Enhanced chat streaming functionality to accept a more structured input, allowing for improved message handling.
  - Introduced a new system prompt for the AI, enhancing its capabilities in code generation.

- **Bug Fixes**
  - Updated error handling for model tag fetching to provide clearer error messages.

- **Documentation**
  - Added a new interface `GenerateMessageParams` to define the structure for message generation parameters.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: Jackson Chen <541898146chen@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
3 people authored Nov 4, 2024
1 parent 92e5065 commit 1d706f1
Showing 9 changed files with 82 additions and 19 deletions.
2 changes: 1 addition & 1 deletion backend/src/chat/chat.resolver.ts
@@ -53,7 +53,7 @@ export class ChatResolver {
MessageRole.User,
);

const iterator = this.chatProxyService.streamChat(input.message);
const iterator = this.chatProxyService.streamChat(input);
let accumulatedContent = '';

for await (const chunk of iterator) {
16 changes: 12 additions & 4 deletions backend/src/chat/chat.service.ts
@@ -5,7 +5,11 @@ import { Message, MessageRole } from 'src/chat/message.model';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository } from 'typeorm';
import { User } from 'src/user/user.model';
import { NewChatInput, UpdateChatTitleInput } from 'src/chat/dto/chat.input';
import {
ChatInput,
NewChatInput,
UpdateChatTitleInput,
} from 'src/chat/dto/chat.input';

type CustomAsyncIterableIterator<T> = AsyncIterator<T> & {
[Symbol.asyncIterator](): AsyncIterableIterator<T>;
@@ -17,8 +21,12 @@ export class ChatProxyService {

constructor(private httpService: HttpService) {}

streamChat(input: string): CustomAsyncIterableIterator<ChatCompletionChunk> {
this.logger.debug('request chat input: ' + input);
streamChat(
input: ChatInput,
): CustomAsyncIterableIterator<ChatCompletionChunk> {
this.logger.debug(
`Request chat input: ${input.message} with model: ${input.model}`,
);
let isDone = false;
let responseSubscription: any;
const chunkQueue: ChatCompletionChunk[] = [];
@@ -60,7 +68,7 @@
responseSubscription = this.httpService
.post(
'http://localhost:3001/chat/completion',
{ content: input },
{ content: input.message, model: input.model },
{ responseType: 'stream' },
)
.subscribe({
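Note: the ChatInput DTO imported above is defined in backend/src/chat/dto/chat.input.ts, which is not part of this diff view. A minimal sketch of its assumed shape, inferred only from the fields the service reads (input.message and input.model); the actual class may carry more fields and validation decorators.

import { Field, InputType } from '@nestjs/graphql';

@InputType()
export class ChatInput {
  @Field()
  message: string; // forwarded to the llm-server as `content`

  @Field()
  model: string; // model name forwarded to the llm-server, e.g. 'gpt-3.5-turbo'
}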
5 changes: 3 additions & 2 deletions llm-server/src/llm-provider.ts
@@ -3,6 +3,7 @@ import { ModelProvider } from './model/model-provider';
import { OpenAIModelProvider } from './model/openai-model-provider';
import { LlamaModelProvider } from './model/llama-model-provider';
import { Logger } from '@nestjs/common';
import { GenerateMessageParams } from './type/GenerateMessage';

export interface ChatMessageInput {
content: string;
@@ -32,10 +33,10 @@ export class LLMProvider {
}

async generateStreamingResponse(
content: string,
params: GenerateMessageParams,
res: Response,
): Promise<void> {
await this.modelProvider.generateStreamingResponse(content, res);
await this.modelProvider.generateStreamingResponse(params, res);
}

async getModelTags(res: Response): Promise<void> {
14 changes: 12 additions & 2 deletions llm-server/src/main.ts
@@ -1,6 +1,7 @@
import { Logger } from '@nestjs/common';
import { ChatMessageInput, LLMProvider } from './llm-provider';
import express, { Express, Request, Response } from 'express';
import { GenerateMessageParams } from './type/GenerateMessage';

export class App {
private readonly logger = new Logger(App.name);
@@ -27,13 +28,22 @@ export class App {
this.logger.log('Received chat request.');
try {
this.logger.debug(JSON.stringify(req.body));
const { content } = req.body as ChatMessageInput;
const { content, model } = req.body as ChatMessageInput & {
model: string;
};

const params: GenerateMessageParams = {
model: model || 'gpt-3.5-turbo', // Default to 'gpt-3.5-turbo' if model is not provided
message: content,
role: 'user',
};

this.logger.debug(`Request content: "${content}"`);
res.setHeader('Content-Type', 'text/event-stream');
res.setHeader('Cache-Control', 'no-cache');
res.setHeader('Connection', 'keep-alive');
this.logger.debug('Response headers set for streaming.');
await this.llmProvider.generateStreamingResponse(content, res);
await this.llmProvider.generateStreamingResponse(params, res);
} catch (error) {
this.logger.error('Error in chat endpoint:', error);
res.status(500).json({ error: 'Internal server error' });
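For context, main.ts serves the same endpoint ChatProxyService posts to (http://localhost:3001/chat/completion). A hedged sketch of calling it directly from Node 18+: the request body shape matches what chat.service.ts sends, while the prompt text and the reading loop are purely illustrative.

// Illustrative only: POST to the llm-server and print the raw SSE frames
// ("data: {...}\n\n") written by generateStreamingResponse.
async function streamCompletion(): Promise<void> {
  const res = await fetch('http://localhost:3001/chat/completion', {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      content: 'Write a hello-world HTTP server', // becomes params.message
      model: 'gpt-3.5-turbo', // if omitted, main.ts falls back to 'gpt-3.5-turbo'
    }),
  });

  const reader = res.body!.getReader();
  const decoder = new TextDecoder();
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    process.stdout.write(decoder.decode(value));
  }
}

streamCompletion().catch(console.error);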
21 changes: 19 additions & 2 deletions llm-server/src/model/llama-model-provider.ts
@@ -8,6 +8,9 @@ import {
} from 'node-llama-cpp';
import { ModelProvider } from './model-provider.js';
import { Logger } from '@nestjs/common';
import { systemPrompts } from '../prompt/systemPrompt';
import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
import { GenerateMessageParams } from '../type/GenerateMessage';

//TODO: using protocol class
export class LlamaModelProvider extends ModelProvider {
@@ -33,7 +36,7 @@ export class LlamaModelProvider extends ModelProvider {
}

async generateStreamingResponse(
content: string,
{ model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
this.logger.log('Generating streaming response with Llama...');
@@ -43,8 +46,22 @@
this.logger.log('LlamaChatSession created.');
let chunkCount = 0;
const startTime = Date.now();

// Get the system prompt based on the model
const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';

const messages = [
{ role: 'system', content: systemPrompt },
{ role: role as 'user' | 'system' | 'assistant', content: message },
];

// Convert messages array to a single formatted string for Llama
const formattedPrompt = messages
.map(({ role, content }) => `${role}: ${content}`)
.join('\n');

try {
await session.prompt(content, {
await session.prompt(formattedPrompt, {
onTextChunk: chunk => {
chunkCount++;
this.logger.debug(`Sending chunk #${chunkCount}: "${chunk}"`);
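For clarity, the flattened prompt handed to session.prompt() looks like this for a sample message (system prompt abbreviated; the full text is added in llm-server/src/prompt/systemPrompt.ts later in this commit):

// Illustration only: the single string the Llama provider builds from the
// system and user messages.
const formattedPrompt = [
  { role: 'system', content: 'You are CodeFox, an advanced and powerful AI...' },
  { role: 'user', content: 'Generate a NestJS controller for /health' },
]
  .map(({ role, content }) => `${role}: ${content}`)
  .join('\n');
// formattedPrompt ===
//   'system: You are CodeFox, an advanced and powerful AI...\n' +
//   'user: Generate a NestJS controller for /health'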
3 changes: 2 additions & 1 deletion llm-server/src/model/model-provider.ts
@@ -1,9 +1,10 @@
import { Response } from 'express';
import { GenerateMessageParams } from '../type/GenerateMessage';

export abstract class ModelProvider {
abstract initialize(): Promise<void>;
abstract generateStreamingResponse(
content: string,
params: GenerateMessageParams,
res: Response,
): Promise<void>;

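Any provider now receives the full GenerateMessageParams object rather than a bare string. A hypothetical extra provider against the updated signature might look like the sketch below; the model-tag method is stubbed on the assumption that the base class declares it the way the OpenAI provider implements it.

import { Response } from 'express';
import { ModelProvider } from './model-provider';
import { GenerateMessageParams } from '../type/GenerateMessage';

// Hypothetical provider, shown only to illustrate the new abstract signature;
// it echoes the incoming message back as a single SSE frame.
export class EchoModelProvider extends ModelProvider {
  async initialize(): Promise<void> {
    // nothing to set up for an echo provider
  }

  async generateStreamingResponse(
    { model, message, role = 'user' }: GenerateMessageParams,
    res: Response,
  ): Promise<void> {
    res.writeHead(200, {
      'Content-Type': 'text/event-stream',
      'Cache-Control': 'no-cache',
      Connection: 'keep-alive',
    });
    res.write(`data: ${JSON.stringify({ model, role, echo: message })}\n\n`);
    res.end();
  }

  // Assumed to be part of ModelProvider, mirroring getModelTagsResponse in
  // the OpenAI provider; returns an empty model list here.
  async getModelTagsResponse(res: Response): Promise<void> {
    res.json({ models: { data: [] } });
  }
}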
28 changes: 21 additions & 7 deletions llm-server/src/model/openai-model-provider.ts
@@ -2,6 +2,10 @@ import { Response } from 'express';
import OpenAI from 'openai';
import { ModelProvider } from './model-provider';
import { Logger } from '@nestjs/common';
import { systemPrompts } from '../prompt/systemPrompt';
import { ChatCompletionMessageParam } from 'openai/resources/chat/completions';
import { GenerateMessageParams } from '../type/GenerateMessage';

export class OpenAIModelProvider extends ModelProvider {
private readonly logger = new Logger(OpenAIModelProvider.name);
private openai: OpenAI;
@@ -15,23 +19,34 @@ export class OpenAIModelProvider extends ModelProvider {
}

async generateStreamingResponse(
content: string,
{ model, message, role = 'user' }: GenerateMessageParams,
res: Response,
): Promise<void> {
this.logger.log('Generating streaming response with OpenAI...');
const startTime = Date.now();

// Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});

// Get the system prompt based on the model
const systemPrompt = systemPrompts['codefox-basic']?.systemPrompt || '';

const messages: ChatCompletionMessageParam[] = [
{ role: 'system', content: systemPrompt },
{ role: role as 'user' | 'system' | 'assistant', content: message },
];

try {
const stream = await this.openai.chat.completions.create({
model: 'gpt-3.5-turbo',
messages: [{ role: 'user', content: content }],
model,
messages,
stream: true,
});

let chunkCount = 0;
for await (const chunk of stream) {
const content = chunk.choices[0]?.delta?.content || '';
@@ -41,6 +56,7 @@
res.write(`data: ${JSON.stringify(chunk)}\n\n`);
}
}

const endTime = Date.now();
this.logger.log(
`Response generation completed. Total chunks: ${chunkCount}`,
@@ -59,20 +75,18 @@

async getModelTagsResponse(res: Response): Promise<void> {
this.logger.log('Fetching available models from OpenAI...');
// Set SSE headers
res.writeHead(200, {
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});

try {
const startTime = Date.now();
const models = await this.openai.models.list();

const response = {
models: models, // Wrap the models in the required structure
models: models,
};

const endTime = Date.now();
this.logger.log(
`Model fetching completed. Total models: ${models.data.length}`,
7 changes: 7 additions & 0 deletions llm-server/src/prompt/systemPrompt.ts
@@ -0,0 +1,7 @@
// Define and export the system prompts object
export const systemPrompts = {
'codefox-basic': {
systemPrompt: `You are CodeFox, an advanced and powerful AI specialized in code generation and software engineering.
Your purpose is to help developers build complete and efficient applications by providing well-structured, optimized, and maintainable code.`,
},
};
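
Both providers look up systemPrompts['codefox-basic'] and fall back to an empty string when the key is missing. Adding a variant would just be another keyed entry; the second key below is purely hypothetical and nothing in this commit reads it yet.

// Hypothetical extension of the prompt map; only 'codefox-basic' is consumed
// by the providers in this commit.
export const systemPrompts = {
  'codefox-basic': {
    systemPrompt: `You are CodeFox, an advanced and powerful AI specialized in code generation and software engineering.
Your purpose is to help developers build complete and efficient applications by providing well-structured, optimized, and maintainable code.`,
  },
  'codefox-concise': {
    systemPrompt: `You are CodeFox. Answer with minimal, working code and short explanations.`,
  },
};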
5 changes: 5 additions & 0 deletions llm-server/src/type/GenerateMessage.ts
@@ -0,0 +1,5 @@
export interface GenerateMessageParams {
model: string; // Model to use, e.g., 'gpt-3.5-turbo'
message: string; // User's message or query
role?: 'user' | 'system' | 'assistant' | 'tool' | 'function'; // Optional role
}
