Skip to content

Commit

Permalink
Merge pull request #8 from So-Sahari/feature/store-chat-history
Browse files Browse the repository at this point in the history
Feature/store chat history
  • Loading branch information
catpaladin authored Jul 12, 2024
2 parents e57facb + 05b9757 commit dd0e6fd
Show file tree
Hide file tree
Showing 13 changed files with 491 additions and 100 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
tmp/*
jenn-ai
chat.db
25 changes: 25 additions & 0 deletions app/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package app

type ModelConfig struct {
Platform string
ModelID string

Temperature float64
TopP float64
TopK int
MaxTokens int

Region string // used for AWS models
}

func NewModelConfig(platform, modelID, region string, temp, topP float64, topK, maxTokens int) ModelConfig {
return ModelConfig{
Platform: platform,
ModelID: modelID,
Temperature: temp,
TopP: topP,
TopK: topK,
MaxTokens: maxTokens,
Region: region,
}
}
150 changes: 150 additions & 0 deletions app/database.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package app

import (
"database/sql"
"log"

_ "github.com/mattn/go-sqlite3"
)

var db *sql.DB

func initDB() {
var err error
db, err = sql.Open("sqlite3", "./chat.db")
if err != nil {
log.Fatal(err)
}

sqlStmt := `
CREATE TABLE IF NOT EXISTS conversations (
id INTEGER PRIMARY KEY AUTOINCREMENT
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
conversation_id INTEGER NOT NULL,
human TEXT,
response TEXT,
platform TEXT,
model TEXT,
FOREIGN KEY (conversation_id) REFERENCES conversations(id)
);
`
_, err = db.Exec(sqlStmt)
if err != nil {
log.Fatal(err)
}
}

func insertMessage(conversationID int, human, response, platform, model string) error {
stmt, err := db.Prepare("INSERT INTO messages(conversation_id, human, response, platform, model) VALUES(?, ?, ?, ?, ?)")
if err != nil {
return err
}
defer stmt.Close()

_, err = stmt.Exec(conversationID, human, response, platform, model)
return err
}

func getMessagesByConversationID(conversationID int) ([]Message, error) {
rows, err := db.Query(`
SELECT id, conversation_id, COALESCE(human, ''), COALESCE(response, ''), platform, model
FROM messages
WHERE conversation_id = ?
ORDER BY id ASC`, conversationID)
if err != nil {
return nil, err
}
defer rows.Close()

var messages []Message
for rows.Next() {
var msg Message
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Human, &msg.Response, &msg.Platform, &msg.Model); err != nil {
log.Printf("Error scanning row: %v", err)
return nil, err
}
messages = append(messages, msg)
}
return messages, nil
}

func getAllConversations() ([]Conversation, error) {
rows, err := db.Query(`
SELECT c.id, COALESCE(m.human, ''), COALESCE(m.response, '')
FROM conversations c
LEFT JOIN messages m ON m.id = (
SELECT id FROM messages
WHERE conversation_id = c.id
ORDER BY id DESC
LIMIT 1
)
ORDER BY c.id
`)
if err != nil {
log.Printf("Error executing query: %v", err)
return nil, err
}
defer rows.Close()

var conversations []Conversation
for rows.Next() {
var conv Conversation
if err := rows.Scan(&conv.ID, &conv.LatestHuman, &conv.LatestResponse); err != nil {
log.Printf("Error scanning row: %v", err)
return nil, err
}
conversations = append(conversations, conv)
}
return conversations, nil
}

func getMessageByID(id int) (Message, error) {
var msg Message
err := db.QueryRow("SELECT id, conversation_id, human, response, platform, model FROM messages WHERE id = ?", id).Scan(&msg.ID, &msg.ConversationID, &msg.Human, &msg.Response, &msg.Platform, &msg.Model)
return msg, err
}

func createNewConversation() (int, error) {
stmt, err := db.Prepare("INSERT INTO conversations DEFAULT VALUES RETURNING id")
if err != nil {
return 0, err
}
defer stmt.Close()

var conversationID int
err = stmt.QueryRow().Scan(&conversationID)
if err != nil {
return 0, err
}
return conversationID, nil
}

func deleteConversation(conversationID int) error {
// Delete associated messages first
_, err := db.Exec("DELETE FROM messages WHERE conversation_id = ?", conversationID)
if err != nil {
return err
}

// Delete the conversation
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", conversationID)
return err
}

type Message struct {
ID int `json:"id"`
ConversationID int `json:"conversation_id"`
Human string `json:"human"`
Response string `json:"response"`
Platform string `json:"platform"`
Model string `json:"model"`
}

type Conversation struct {
ID int `json:"id"`
LatestHuman string `json:"latest_human"`
LatestResponse string `json:"latest_response"`
}
124 changes: 124 additions & 0 deletions app/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package app

import (
"html/template"
"log"
"net/http"
"strconv"
"strings"

"jenn-ai/internal/parser"
"jenn-ai/internal/state"

"github.com/gin-gonic/gin"
)

type ChatMessage struct {
Human template.HTML
Response template.HTML
Platform string
Model string
}

func createConversation(c *gin.Context) {
conversationID, err := createNewConversation()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create a new conversation"})
return
}

appState := state.GetState()
appState.SetConversationID(conversationID)

c.HTML(http.StatusOK, "chat.html", gin.H{
"ChatMessages": []ChatMessage{},
"Platform": appState.GetPlatform(),
"Model": appState.GetModel(),
})
}

func getMessagesFromDB(c *gin.Context) {
conversationID, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid conversation ID"})
return
}

messages, err := getMessagesByConversationID(conversationID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve messages"})
return
}

appState := state.GetState()
if len(messages) > 0 {
lastMessage := messages[len(messages)-1]
appState.SetPlatform(lastMessage.Platform)
appState.SetModel(lastMessage.Model)
appState.SetConversationID(lastMessage.ConversationID)
}

var chatMessages []ChatMessage
for _, msg := range messages {
// parse markdown
parsed, err := parser.ParseMD(msg.Response)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

parsed = strings.ReplaceAll(parsed, "<pre>", "<div class='card bg-base-100 shadow-xl'><div class='card-body text-white'><pre>")
parsed = strings.ReplaceAll(parsed, "</pre>", "</pre></div></div>")

chatMessages = append(chatMessages, ChatMessage{
Human: template.HTML(msg.Human),
Response: template.HTML(parsed),
Platform: msg.Platform,
Model: msg.Model,
})
}

c.HTML(http.StatusOK, "chat.html", gin.H{
"ChatMessages": chatMessages,
})
}

func getAllMessagesFromDB(c *gin.Context) {
conversations, err := getAllConversations()
if err != nil {
log.Print(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.HTML(http.StatusOK, "sidebar.html", gin.H{
"Conversations": conversations,
})
}

func getAllConversationsHandler(c *gin.Context) {
conversations, err := getAllConversations()
if err != nil {
log.Print(err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.HTML(http.StatusOK, "sidebar.html", gin.H{
"Conversations": conversations,
})
}

func deleteChat(c *gin.Context) {
conversationID, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid conversation ID"})
return
}

err = deleteConversation(conversationID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete conversation"})
return
}

c.JSON(http.StatusOK, gin.H{"status": "Conversation deleted"})
}
Loading

0 comments on commit dd0e6fd

Please sign in to comment.