Skip to content

Commit

Permalink
refactor: Refactor storage and new serve template
Browse files Browse the repository at this point in the history
  • Loading branch information
fangyinc committed Dec 18, 2023
1 parent 9e6a26b commit be82161
Show file tree
Hide file tree
Showing 50 changed files with 1,491 additions and 40 deletions.
6 changes: 3 additions & 3 deletions dbgpt/agent/db/my_plugin_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class MyPluginEntity(Model):
UniqueConstraint("user_code", "name", name="uk_name")


class MyPluginDao(BaseDao[MyPluginEntity]):
class MyPluginDao(BaseDao):
def add(self, engity: MyPluginEntity):
session = self.get_raw_session()
my_plugin = MyPluginEntity(
Expand All @@ -49,7 +49,7 @@ def add(self, engity: MyPluginEntity):
session.close()
return id

def update(self, entity: MyPluginEntity):
def raw_update(self, entity: MyPluginEntity):
session = self.get_raw_session()
updated = session.merge(entity)
session.commit()
Expand Down Expand Up @@ -124,7 +124,7 @@ def count(self, query: MyPluginEntity):
session.close()
return count

def delete(self, plugin_id: int):
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
Expand Down
6 changes: 3 additions & 3 deletions dbgpt/agent/db/plugin_hub_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class PluginHubEntity(Model):
Index("idx_q_type", "type")


class PluginHubDao(BaseDao[PluginHubEntity]):
class PluginHubDao(BaseDao):
def add(self, engity: PluginHubEntity):
session = self.get_raw_session()
timezone = pytz.timezone("Asia/Shanghai")
Expand All @@ -52,7 +52,7 @@ def add(self, engity: PluginHubEntity):
session.close()
return id

def update(self, entity: PluginHubEntity):
def raw_update(self, entity: PluginHubEntity):
session = self.get_raw_session()
try:
updated = session.merge(entity)
Expand Down Expand Up @@ -127,7 +127,7 @@ def count(self, query: PluginHubEntity):
session.close()
return count

def delete(self, plugin_id: int):
def raw_delete(self, plugin_id: int):
session = self.get_raw_session()
if plugin_id is None:
raise Exception("plugin_id is None")
Expand Down
4 changes: 2 additions & 2 deletions dbgpt/agent/hub/agent_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def refresh_hub_from_git(
plugin_hub_info.name = git_plugin._name
plugin_hub_info.version = git_plugin._version
plugin_hub_info.description = git_plugin._description
self.hub_dao.update(plugin_hub_info)
self.hub_dao.raw_update(plugin_hub_info)
except Exception as e:
raise ValueError(f"Update Agent Hub Db Info Faild!{str(e)}")

Expand Down Expand Up @@ -194,7 +194,7 @@ async def upload_my_plugin(self, doc_file: UploadFile, user: Any = Default_User)
my_plugin_entiy.user_name = user
my_plugin_entiy.tenant = ""
my_plugin_entiy.file_name = doc_file.filename
self.my_plugin_dao.update(my_plugin_entiy)
self.my_plugin_dao.raw_update(my_plugin_entiy)

def reload_my_plugins(self):
logger.info(f"load_plugins start!")
Expand Down
19 changes: 15 additions & 4 deletions dbgpt/app/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,26 @@ def _initialize_db_storage(param: "WebServerParameters"):
Now just support sqlite and mysql. If db type is sqlite, the db path is `pilot/meta_data/{db_name}.db`.
"""
_initialize_db(try_to_create_db=not param.disable_alembic_upgrade)


def _migration_db_storage(param: "WebServerParameters"):
"""Migration the db storage."""
# Import all models to make sure they are registered with SQLAlchemy.
from dbgpt.app.initialization.db_model_initialization import _MODELS

default_meta_data_path = _initialize_db(
try_to_create_db=not param.disable_alembic_upgrade
)
from dbgpt.configs.model_config import PILOT_PATH

default_meta_data_path = os.path.join(PILOT_PATH, "meta_data")
if not param.disable_alembic_upgrade:
from dbgpt.util._db_migration_utils import _ddl_init_and_upgrade
from dbgpt.storage.metadata.db_manager import db

# try to create all tables
try:
db.create_all()
except Exception as e:
logger.warning(f"Create all tables stored in this metadata error: {str(e)}")
_ddl_init_and_upgrade(default_meta_data_path, param.disable_alembic_upgrade)


Expand Down Expand Up @@ -136,7 +147,7 @@ def _initialize_db(try_to_create_db: Optional[bool] = False) -> str:
"pool_recycle": 3600,
"pool_pre_ping": True,
}
initialize_db(db_url, db_name, engine_args, try_to_create_db=try_to_create_db)
initialize_db(db_url, db_name, engine_args)
return default_meta_data_path


Expand Down
3 changes: 3 additions & 0 deletions dbgpt/app/component_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def initialize_components(
# Lazy import to avoid high time cost
from dbgpt.model.cluster.controller.controller import controller
from dbgpt.app.initialization.embedding_component import _initialize_embedding_model
from dbgpt.app.initialization.serve_initialization import register_serve_apps

# Register global default executor factory first
system_app.register(DefaultExecutorFactory)
Expand All @@ -42,6 +43,8 @@ def initialize_components(
)
_initialize_model_cache(system_app)
_initialize_awel(system_app)
# Register serve apps
register_serve_apps(system_app)


def _initialize_model_cache(system_app: SystemApp):
Expand Down
15 changes: 11 additions & 4 deletions dbgpt/app/dbgpt_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from dbgpt.app.base import (
server_init,
_migration_db_storage,
WebServerParameters,
_create_model_start_listener,
)
Expand Down Expand Up @@ -75,7 +76,9 @@ def swagger_monkey_patch(*args, **kwargs):
def mount_routers(app: FastAPI):
"""Lazy import to avoid high time cost"""
from dbgpt.app.knowledge.api import router as knowledge_router
from dbgpt.app.prompt.api import router as prompt_router

# from dbgpt.app.prompt.api import router as prompt_router
# prompt has been removed to dbgpt.serve.prompt
from dbgpt.app.llm_manage.api import router as llm_manage_api

from dbgpt.app.openapi.api_v1.api_v1 import router as api_v1
Expand All @@ -90,7 +93,7 @@ def mount_routers(app: FastAPI):
app.include_router(api_fb_v1, prefix="/api", tags=["FeedBack"])

app.include_router(knowledge_router, tags=["Knowledge"])
app.include_router(prompt_router, tags=["Prompt"])
# app.include_router(prompt_router, tags=["Prompt"])


def mount_static_files(app: FastAPI):
Expand Down Expand Up @@ -141,8 +144,6 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
"dbgpt", logging_level=param.log_level, logger_filename=param.log_file
)

# Before start
system_app.before_start()
model_name = param.model_name or CFG.LLM_MODEL
param.model_name = model_name
print(param)
Expand All @@ -155,6 +156,12 @@ def initialize_app(param: WebServerParameters = None, args: List[str] = None):
model_start_listener = _create_model_start_listener(system_app)
initialize_components(param, system_app, embedding_model_name, embedding_model_path)

# Before start, after initialize_components
# TODO: initialize_worker_manager_in_client as a component register in system_app
system_app.before_start()
# Migration db storage, so you db models must be imported before this
_migration_db_storage(param)

model_path = CFG.LLM_MODEL_PATH or LLM_MODEL_CONFIG.get(model_name)
if not param.light:
print("Model Unified Deployment Mode!")
Expand Down
5 changes: 3 additions & 2 deletions dbgpt/app/initialization/db_model_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from dbgpt.app.knowledge.document_db import KnowledgeDocumentEntity
from dbgpt.app.knowledge.space_db import KnowledgeSpaceEntity
from dbgpt.app.openapi.api_v1.feedback.feed_back_db import ChatFeedBackEntity
from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity

# from dbgpt.app.prompt.prompt_manage_db import PromptManageEntity
from dbgpt.datasource.manages.connect_config_db import ConnectConfigEntity
from dbgpt.storage.chat_history.chat_history_db import (
ChatHistoryEntity,
Expand All @@ -16,7 +17,7 @@
_MODELS = [
PluginHubEntity,
MyPluginEntity,
PromptManageEntity,
# PromptManageEntity,
KnowledgeSpaceEntity,
KnowledgeDocumentEntity,
DocumentChunkEntity,
Expand Down
9 changes: 9 additions & 0 deletions dbgpt/app/initialization/serve_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dbgpt.component import SystemApp


def register_serve_apps(system_app: SystemApp):
"""Register serve apps"""
from dbgpt.serve.prompt.serve import Serve as PromptServe

# Replace old prompt serve
system_app.register(PromptServe, api_prefix="/prompt")
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/chunk_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def get_document_chunks_count(self, query: DocumentChunkEntity):
session.close()
return count

def delete(self, document_id: int):
def raw_delete(self, document_id: int):
session = self.get_raw_session()
if document_id is None:
raise Exception("document_id is None")
Expand Down
2 changes: 1 addition & 1 deletion dbgpt/app/knowledge/document_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def update_knowledge_document(self, document: KnowledgeDocumentEntity):
return updated_space.id

#
def delete(self, query: KnowledgeDocumentEntity):
def raw_delete(self, query: KnowledgeDocumentEntity):
session = self.get_raw_session()
knowledge_documents = session.query(KnowledgeDocumentEntity)
if query.id is not None:
Expand Down
8 changes: 4 additions & 4 deletions dbgpt/app/knowledge/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,9 @@ def delete_space(self, space_name: str):
# delete chunks
documents = knowledge_document_dao.get_documents(document_query)
for document in documents:
document_chunk_dao.delete(document.id)
document_chunk_dao.raw_delete(document.id)
# delete documents
knowledge_document_dao.delete(document_query)
knowledge_document_dao.raw_delete(document_query)
# delete space
return knowledge_space_dao.delete_knowledge_space(space)

Expand All @@ -395,9 +395,9 @@ def delete_document(self, space_name: str, doc_name: str):
# delete vector by ids
vector_client.delete_by_ids(vector_ids)
# delete chunks
document_chunk_dao.delete(documents[0].id)
document_chunk_dao.raw_delete(documents[0].id)
# delete document
return knowledge_document_dao.delete(document_query)
return knowledge_document_dao.raw_delete(document_query)

def get_document_chunks(self, request: ChunkQueryRequest):
"""get document chunks
Expand Down
14 changes: 14 additions & 0 deletions dbgpt/cli/cli_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def db():
pass


@click.group()
def new():
"""New a template"""
pass


stop_all_func_list = []


Expand All @@ -71,6 +77,7 @@ def stop_all():
cli.add_command(stop)
cli.add_command(install)
cli.add_command(db)
cli.add_command(new)
add_command_alias(stop_all, name="all", parent_group=stop)

try:
Expand Down Expand Up @@ -130,6 +137,13 @@ def stop_all():
except ImportError as e:
logging.warning(f"Integrating dbgpt trace command line tool failed: {e}")

try:
from dbgpt.serve.utils.cli import serve

add_command_alias(serve, name="serve", parent_group=new)
except ImportError as e:
logging.warning(f"Integrating dbgpt serve command line tool failed: {e}")


def main():
return cli()
Expand Down
13 changes: 12 additions & 1 deletion dbgpt/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
import asyncio
from dbgpt.util.annotations import PublicAPI
from dbgpt.util import AppConfig

# Checking for type hints during runtime
if TYPE_CHECKING:
Expand Down Expand Up @@ -87,17 +88,27 @@ def init_app(self, system_app: SystemApp):
class SystemApp(LifeCycle):
"""Main System Application class that manages the lifecycle and registration of components."""

def __init__(self, asgi_app: Optional["FastAPI"] = None) -> None:
def __init__(
self,
asgi_app: Optional["FastAPI"] = None,
app_config: Optional[AppConfig] = None,
) -> None:
self.components: Dict[
str, BaseComponent
] = {} # Dictionary to store registered components.
self._asgi_app = asgi_app
self._app_config = app_config or AppConfig()

@property
def app(self) -> Optional["FastAPI"]:
"""Returns the internal ASGI app."""
return self._asgi_app

@property
def config(self) -> AppConfig:
"""Returns the internal AppConfig."""
return self._app_config

def register(self, component: Type[BaseComponent], *args, **kwargs) -> T:
"""Register a new component by its type.
Expand Down
10 changes: 9 additions & 1 deletion dbgpt/core/interface/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,11 @@ def __init__(
super().__init__(chat_mode, user_name, sys_code, summary, **kwargs)
self.conv_uid = conv_uid
self._message_ids = message_ids
# Record the message index last time saved to the storage,
# next time save messages which index is _has_stored_message_index + 1
self._has_stored_message_index = (
len(kwargs["messages"]) - 1 if "messages" in kwargs else -1
)
self.save_message_independent = save_message_independent
self._id = ConversationIdentifier(conv_uid)
if conv_storage is None:
Expand Down Expand Up @@ -695,7 +700,9 @@ def save_to_storage(self) -> None:
self._message_ids = [
message.identifier.str_identifier for message in message_list
]
self.message_storage.save_list(message_list)
messages_to_save = message_list[self._has_stored_message_index + 1 :]
self._has_stored_message_index = len(message_list) - 1
self.message_storage.save_list(messages_to_save)
# Save conversation
self.conv_storage.save_or_update(self)

Expand Down Expand Up @@ -729,6 +736,7 @@ def load_from_storage(
messages = [message.to_message() for message in message_list]
conversation.messages = messages
self._message_ids = message_ids
self._has_stored_message_index = len(messages) - 1
self.from_conversation(conversation)


Expand Down
2 changes: 1 addition & 1 deletion dbgpt/datasource/manages/connect_config_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ConnectConfigEntity(Model):
)


class ConnectConfigDao(BaseDao[ConnectConfigEntity]):
class ConnectConfigDao(BaseDao):
"""db connect config dao"""

def update(self, entity: ConnectConfigEntity):
Expand Down
5 changes: 5 additions & 0 deletions dbgpt/serve/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from dbgpt.serve.core.schemas import Result
from dbgpt.serve.core.config import BaseServeConfig
from dbgpt.serve.core.service import BaseService

__ALL__ = ["Result", "BaseServeConfig", "BaseService"]
19 changes: 19 additions & 0 deletions dbgpt/serve/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from dataclasses import dataclass
from dbgpt.component import AppConfig
from dbgpt.util import BaseParameters


@dataclass
class BaseServeConfig(BaseParameters):
"""Base configuration class for serve"""

@classmethod
def from_app_config(cls, config: AppConfig, config_prefix: str):
"""Create a configuration object from a dictionary
Args:
config (AppConfig): Application configuration
config_prefix (str): Configuration prefix
"""
config_dict = config.get_all_by_prefix(config_prefix)
return cls(**config_dict)
Loading

0 comments on commit be82161

Please sign in to comment.