-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from So-Sahari/feature/store-chat-history
Feature/store chat history
- Loading branch information
Showing
13 changed files
with
491 additions
and
100 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
tmp/* | ||
jenn-ai | ||
chat.db |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
package app | ||
|
||
type ModelConfig struct { | ||
Platform string | ||
ModelID string | ||
|
||
Temperature float64 | ||
TopP float64 | ||
TopK int | ||
MaxTokens int | ||
|
||
Region string // used for AWS models | ||
} | ||
|
||
func NewModelConfig(platform, modelID, region string, temp, topP float64, topK, maxTokens int) ModelConfig { | ||
return ModelConfig{ | ||
Platform: platform, | ||
ModelID: modelID, | ||
Temperature: temp, | ||
TopP: topP, | ||
TopK: topK, | ||
MaxTokens: maxTokens, | ||
Region: region, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
package app | ||
|
||
import ( | ||
"database/sql" | ||
"log" | ||
|
||
_ "github.com/mattn/go-sqlite3" | ||
) | ||
|
||
var db *sql.DB | ||
|
||
func initDB() { | ||
var err error | ||
db, err = sql.Open("sqlite3", "./chat.db") | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
sqlStmt := ` | ||
CREATE TABLE IF NOT EXISTS conversations ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT | ||
); | ||
CREATE TABLE IF NOT EXISTS messages ( | ||
id INTEGER PRIMARY KEY AUTOINCREMENT, | ||
conversation_id INTEGER NOT NULL, | ||
human TEXT, | ||
response TEXT, | ||
platform TEXT, | ||
model TEXT, | ||
FOREIGN KEY (conversation_id) REFERENCES conversations(id) | ||
); | ||
` | ||
_, err = db.Exec(sqlStmt) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
} | ||
|
||
func insertMessage(conversationID int, human, response, platform, model string) error { | ||
stmt, err := db.Prepare("INSERT INTO messages(conversation_id, human, response, platform, model) VALUES(?, ?, ?, ?, ?)") | ||
if err != nil { | ||
return err | ||
} | ||
defer stmt.Close() | ||
|
||
_, err = stmt.Exec(conversationID, human, response, platform, model) | ||
return err | ||
} | ||
|
||
func getMessagesByConversationID(conversationID int) ([]Message, error) { | ||
rows, err := db.Query(` | ||
SELECT id, conversation_id, COALESCE(human, ''), COALESCE(response, ''), platform, model | ||
FROM messages | ||
WHERE conversation_id = ? | ||
ORDER BY id ASC`, conversationID) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
|
||
var messages []Message | ||
for rows.Next() { | ||
var msg Message | ||
if err := rows.Scan(&msg.ID, &msg.ConversationID, &msg.Human, &msg.Response, &msg.Platform, &msg.Model); err != nil { | ||
log.Printf("Error scanning row: %v", err) | ||
return nil, err | ||
} | ||
messages = append(messages, msg) | ||
} | ||
return messages, nil | ||
} | ||
|
||
func getAllConversations() ([]Conversation, error) { | ||
rows, err := db.Query(` | ||
SELECT c.id, COALESCE(m.human, ''), COALESCE(m.response, '') | ||
FROM conversations c | ||
LEFT JOIN messages m ON m.id = ( | ||
SELECT id FROM messages | ||
WHERE conversation_id = c.id | ||
ORDER BY id DESC | ||
LIMIT 1 | ||
) | ||
ORDER BY c.id | ||
`) | ||
if err != nil { | ||
log.Printf("Error executing query: %v", err) | ||
return nil, err | ||
} | ||
defer rows.Close() | ||
|
||
var conversations []Conversation | ||
for rows.Next() { | ||
var conv Conversation | ||
if err := rows.Scan(&conv.ID, &conv.LatestHuman, &conv.LatestResponse); err != nil { | ||
log.Printf("Error scanning row: %v", err) | ||
return nil, err | ||
} | ||
conversations = append(conversations, conv) | ||
} | ||
return conversations, nil | ||
} | ||
|
||
func getMessageByID(id int) (Message, error) { | ||
var msg Message | ||
err := db.QueryRow("SELECT id, conversation_id, human, response, platform, model FROM messages WHERE id = ?", id).Scan(&msg.ID, &msg.ConversationID, &msg.Human, &msg.Response, &msg.Platform, &msg.Model) | ||
return msg, err | ||
} | ||
|
||
func createNewConversation() (int, error) { | ||
stmt, err := db.Prepare("INSERT INTO conversations DEFAULT VALUES RETURNING id") | ||
if err != nil { | ||
return 0, err | ||
} | ||
defer stmt.Close() | ||
|
||
var conversationID int | ||
err = stmt.QueryRow().Scan(&conversationID) | ||
if err != nil { | ||
return 0, err | ||
} | ||
return conversationID, nil | ||
} | ||
|
||
func deleteConversation(conversationID int) error { | ||
// Delete associated messages first | ||
_, err := db.Exec("DELETE FROM messages WHERE conversation_id = ?", conversationID) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
// Delete the conversation | ||
_, err = db.Exec("DELETE FROM conversations WHERE id = ?", conversationID) | ||
return err | ||
} | ||
|
||
type Message struct { | ||
ID int `json:"id"` | ||
ConversationID int `json:"conversation_id"` | ||
Human string `json:"human"` | ||
Response string `json:"response"` | ||
Platform string `json:"platform"` | ||
Model string `json:"model"` | ||
} | ||
|
||
type Conversation struct { | ||
ID int `json:"id"` | ||
LatestHuman string `json:"latest_human"` | ||
LatestResponse string `json:"latest_response"` | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
package app | ||
|
||
import ( | ||
"html/template" | ||
"log" | ||
"net/http" | ||
"strconv" | ||
"strings" | ||
|
||
"jenn-ai/internal/parser" | ||
"jenn-ai/internal/state" | ||
|
||
"github.com/gin-gonic/gin" | ||
) | ||
|
||
type ChatMessage struct { | ||
Human template.HTML | ||
Response template.HTML | ||
Platform string | ||
Model string | ||
} | ||
|
||
func createConversation(c *gin.Context) { | ||
conversationID, err := createNewConversation() | ||
if err != nil { | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create a new conversation"}) | ||
return | ||
} | ||
|
||
appState := state.GetState() | ||
appState.SetConversationID(conversationID) | ||
|
||
c.HTML(http.StatusOK, "chat.html", gin.H{ | ||
"ChatMessages": []ChatMessage{}, | ||
"Platform": appState.GetPlatform(), | ||
"Model": appState.GetModel(), | ||
}) | ||
} | ||
|
||
func getMessagesFromDB(c *gin.Context) { | ||
conversationID, err := strconv.Atoi(c.Param("id")) | ||
if err != nil { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid conversation ID"}) | ||
return | ||
} | ||
|
||
messages, err := getMessagesByConversationID(conversationID) | ||
if err != nil { | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve messages"}) | ||
return | ||
} | ||
|
||
appState := state.GetState() | ||
if len(messages) > 0 { | ||
lastMessage := messages[len(messages)-1] | ||
appState.SetPlatform(lastMessage.Platform) | ||
appState.SetModel(lastMessage.Model) | ||
appState.SetConversationID(lastMessage.ConversationID) | ||
} | ||
|
||
var chatMessages []ChatMessage | ||
for _, msg := range messages { | ||
// parse markdown | ||
parsed, err := parser.ParseMD(msg.Response) | ||
if err != nil { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
|
||
parsed = strings.ReplaceAll(parsed, "<pre>", "<div class='card bg-base-100 shadow-xl'><div class='card-body text-white'><pre>") | ||
parsed = strings.ReplaceAll(parsed, "</pre>", "</pre></div></div>") | ||
|
||
chatMessages = append(chatMessages, ChatMessage{ | ||
Human: template.HTML(msg.Human), | ||
Response: template.HTML(parsed), | ||
Platform: msg.Platform, | ||
Model: msg.Model, | ||
}) | ||
} | ||
|
||
c.HTML(http.StatusOK, "chat.html", gin.H{ | ||
"ChatMessages": chatMessages, | ||
}) | ||
} | ||
|
||
func getAllMessagesFromDB(c *gin.Context) { | ||
conversations, err := getAllConversations() | ||
if err != nil { | ||
log.Print(err) | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
c.HTML(http.StatusOK, "sidebar.html", gin.H{ | ||
"Conversations": conversations, | ||
}) | ||
} | ||
|
||
func getAllConversationsHandler(c *gin.Context) { | ||
conversations, err := getAllConversations() | ||
if err != nil { | ||
log.Print(err) | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) | ||
return | ||
} | ||
c.HTML(http.StatusOK, "sidebar.html", gin.H{ | ||
"Conversations": conversations, | ||
}) | ||
} | ||
|
||
func deleteChat(c *gin.Context) { | ||
conversationID, err := strconv.Atoi(c.Param("id")) | ||
if err != nil { | ||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid conversation ID"}) | ||
return | ||
} | ||
|
||
err = deleteConversation(conversationID) | ||
if err != nil { | ||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete conversation"}) | ||
return | ||
} | ||
|
||
c.JSON(http.StatusOK, gin.H{"status": "Conversation deleted"}) | ||
} |
Oops, something went wrong.