From 2ac36d339a2f85918abd5684b0319c56260d2336 Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Wed, 6 Dec 2023 22:31:23 +0100 Subject: [PATCH 1/5] Add filesystem abstraction to Python API --- pyproject.toml | 2 + ragna/core/__init__.py | 13 +-- ragna/core/_document.py | 202 ++++++++++++++++++++++++++++++++++---- ragna/core/_rag.py | 8 +- ragna/deploy/_api/core.py | 9 +- 5 files changed, 198 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 92fa6c9d..72e2e240 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "anyio", "emoji", "fastapi", + "fsspec", "httpx", "importlib_metadata>=4.6; python_version<'3.10'", "packaging", @@ -136,6 +137,7 @@ disallow_incomplete_defs = false [[tool.mypy.overrides]] module = [ "fitz", + "fsspec", "lancedb", "param", "pyarrow", diff --git a/ragna/core/__init__.py b/ragna/core/__init__.py index d5d7ad1b..ddf78fb0 100644 --- a/ragna/core/__init__.py +++ b/ragna/core/__init__.py @@ -5,7 +5,8 @@ "Document", "DocumentHandler", "EnvVarRequirement", - "LocalDocument", + "FilesystemDocument", + "filesystem_glob", "Message", "MessageRole", "PackageRequirement", @@ -19,22 +20,18 @@ "TxtDocumentHandler", ] -from ._utils import ( - EnvVarRequirement, - PackageRequirement, - RagnaException, - Requirement, -) +from ._utils import EnvVarRequirement, PackageRequirement, RagnaException, Requirement # isort: split from ._document import ( Document, DocumentHandler, - LocalDocument, + FilesystemDocument, Page, PdfDocumentHandler, TxtDocumentHandler, + filesystem_glob, ) # isort: split diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 7369b0a3..9a6f7288 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -6,12 +6,19 @@ import time import uuid from pathlib import Path -from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, Any, Iterator, Optional, Type, TypeVar, cast +import fsspec import jwt from pydantic import BaseModel -from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin +from ._utils import ( + EnvVarRequirement, + PackageRequirement, + RagnaException, + Requirement, + RequirementsMixin, +) if TYPE_CHECKING: from ragna.deploy import Config @@ -74,35 +81,164 @@ def extract_pages(self) -> Iterator[Page]: yield from self.handler.extract_pages(self) -class LocalDocument(Document): - """Document class for files on the local file system. +class FS(RequirementsMixin, abc.ABC): + "Abstract base class for all fsspec-like filesystems." + + # Identifier for the filesystem, e.g. "local" or "github" + _prefix: str + + @classmethod + @abc.abstractmethod + def create_fs_instance_from_key( + cls, key: str, asynchronous: bool + ) -> fsspec.AbstractFileSystem: + """Create a fsspec filesystem instance from a key.""" + ... + + @staticmethod + @abc.abstractmethod + def resolve_path(key: str) -> str: + """Resolve a key to an absolute path on a fsspec-supported filesystem.""" + ... + + +S = TypeVar("S", bound=FS) + + +class FSRegistry(dict[str, Type[FS]]): + def check_available(self, cls: Type[S]) -> Type[S]: + if cls.is_available(): + self[cls._prefix] = cls + return cls + + +FS_REGISTRY = FSRegistry() + + +@FS_REGISTRY.check_available +class LocalFS(FS): + _prefix: str = "local" + _fs_cache: dict[str, fsspec.AbstractFileSystem] = {} + + @classmethod + def create_fs_instance_from_key( + cls, key: str, asynchronous: bool + ) -> fsspec.AbstractFileSystem: + if key not in cls._fs_cache: + cls._fs_cache[key] = fsspec.filesystem( + cls._prefix, asynchronous=asynchronous + ) + return cls._fs_cache[key] + + @staticmethod + def resolve_path(key: str) -> str: + return str(Path(key).resolve()) + + +@FS_REGISTRY.check_available +class GithubFS(FS): + _prefix: str = "github" + _fs_cache: dict[str, fsspec.AbstractFileSystem] = {} + + @classmethod + def requirements(cls) -> list[Requirement]: + return [ + PackageRequirement("requests"), + EnvVarRequirement("GITHUB_USERNAME"), + EnvVarRequirement("GITHUB_TOKEN"), + ] + + @classmethod + def create_fs_instance_from_key( + cls, key: str, asynchronous: bool + ) -> fsspec.AbstractFileSystem: + # org/repo/path/to/file + org, repo, *_ = key.split("/") + if f"{org}/{repo}" not in cls._fs_cache: + cls._fs_cache[f"{org}/{repo}"] = fsspec.filesystem( + cls._prefix, + org=org, + repo=repo, + username=os.environ["GITHUB_USERNAME"], + token=os.environ["GITHUB_TOKEN"], + asynchronous=asynchronous, + ) + return cls._fs_cache[f"{org}/{repo}"] + + @staticmethod + def resolve_path(key: str) -> str: + # GitHub 'absolute' paths are relative to the repo root + _, _, *ks = key.split("/") + return "/".join(ks) + + +def filesystem_glob(path: str) -> list[str]: + """Glob for files on any filesystem supported by fsspec. + + Args: + path: Path to glob for. + + Returns: + List of paths matching the glob. + """ + try: + prefix, key = path.split("://") + except ValueError: + prefix, key = "local", path + + if prefix not in FS_REGISTRY: + raise RagnaException(f"Unavailable filesystem prefix: {prefix}") + + kls = FS_REGISTRY[prefix] + fs = kls.create_fs_instance_from_key(key, asynchronous=False) + return cast(list[str], fs.glob(kls.resolve_path(key))) + + +class FilesystemDocument(Document): + """Document class for files on any file system supported by fsspec. !!! tip This object is usually not instantiated manually, but rather through - [ragna.core.LocalDocument.from_path][]. + [ragna.core.FilesystemDocument.from_path][]. """ + def __init__( + self, + *, + id: Optional[uuid.UUID] = None, + name: str, + metadata: dict[str, Any], + handler: Optional[DocumentHandler] = None, + fs: fsspec.AbstractFileSystem = None, + ): + super().__init__(id=id, name=name, metadata=metadata, handler=handler) + if fs is None: + self.fs = fsspec.filesystem("local") + else: + self.fs = fs + @classmethod def from_path( cls, - path: Union[str, Path], + path: str, *, id: Optional[uuid.UUID] = None, metadata: Optional[dict[str, Any]] = None, handler: Optional[DocumentHandler] = None, - ) -> LocalDocument: - """Create a [ragna.core.LocalDocument][] from a path. + ) -> FilesystemDocument: + """Create a [ragna.core.FilesystemDocument][] from a path. Args: - path: Local path to the file. + path: Path to the file on the filesystem, including filesystem prefix id: ID of the document. If omitted, one is generated. metadata: Optional metadata of the document. handler: Document handler. If omitted, a builtin handler is selected based on the suffix of the `path`. Raises: - RagnaException: If `metadata` is passed and contains a `"path"` key. + RagnaException: If `metadata` is passed and contains a `"path"` key or + if the filesystem prefix is missing. """ if metadata is None: metadata = {} @@ -112,21 +248,36 @@ def from_path( "Did you mean to instantiate the class directly?" ) - path = Path(path).expanduser().resolve() - metadata["path"] = str(path) + try: + prefix, key = path.split("://") + except ValueError: + prefix, key = "local", path + + if prefix not in FS_REGISTRY: + raise RagnaException(f"Unavailable filesystem prefix: {prefix}") - return cls(id=id, name=path.name, metadata=metadata, handler=handler) + # TODO: Determine if making filesystem operations async is beneficial + kls = FS_REGISTRY[prefix] + fs = kls.create_fs_instance_from_key(key, asynchronous=False) + metadata["path"] = kls.resolve_path(key) + name = os.path.basename(metadata["path"]) + + return cls(id=id, name=name, metadata=metadata, handler=handler, fs=fs) + + @staticmethod + def supported_filesystems() -> set[str]: + return set(FS_REGISTRY.keys()) @property - def path(self) -> Path: - return Path(self.metadata["path"]) + def path(self) -> str: + return cast(str, self.metadata["path"]) def is_readable(self) -> bool: - return self.path.exists() + return cast(bool, self.fs.exists(self.path)) def read(self) -> bytes: - with open(self.path, "rb") as stream: - return stream.read() + with self.fs.open(self.path, "rb") as stream: + return cast(bytes, stream.read()) _JWT_SECRET = os.environ.get( "RAGNA_API_DOCUMENT_UPLOAD_SECRET", secrets.token_urlsafe(32)[:32] @@ -135,7 +286,13 @@ def read(self) -> bytes: @classmethod async def get_upload_info( - cls, *, config: Config, user: str, id: uuid.UUID, name: str + cls, + *, + config: Config, + user: str, + id: uuid.UUID, + name: str, + path: Optional[str] = None, ) -> tuple[str, dict[str, Any], dict[str, Any]]: url = f"{config.api.url}/document" data = { @@ -149,7 +306,12 @@ async def get_upload_info( algorithm=cls._JWT_ALGORITHM, ) } - metadata = {"path": str(config.local_cache_root / "documents" / str(id))} + if path is not None: + metadata = {"path": path} + else: + metadata = { + "path": str(config.local_cache_root / "documents" / str(id)), + } return url, data, metadata @classmethod diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index b1af0739..77449495 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -21,7 +21,7 @@ import pydantic from ._components import Assistant, Component, Message, MessageRole, SourceStorage -from ._document import Document, LocalDocument +from ._document import Document, FilesystemDocument from ._utils import RagnaException, default_user, merge_models T = TypeVar("T") @@ -71,7 +71,7 @@ def chat( Args: documents: Documents to use. If any item is not a [ragna.core.Document][], - [ragna.core.LocalDocument.from_path][] is invoked on it. + [ragna.core.FilesystemDocument.from_path][] is invoked on it. source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. @@ -120,7 +120,7 @@ class Chat: Args: rag: The RAG workflow this chat is associated with. documents: Documents to use. If any item is not a [ragna.core.Document][], - [ragna.core.LocalDocument.from_path][] is invoked on it. + [ragna.core.FilesystemDocument.from_path][] is invoked on it. source_storage: Source storage to use. assistant: Assistant to use. **params: Additional parameters passed to the source storage and assistant. @@ -225,7 +225,7 @@ def _parse_documents(self, documents: Iterable[Any]) -> list[Document]: documents_ = [] for document in documents: if not isinstance(document, Document): - document = LocalDocument.from_path(document) + document = FilesystemDocument.from_path(document) if not document.is_readable(): raise RagnaException( diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index c4abe84e..8cdd55c0 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -1,6 +1,7 @@ import contextlib import itertools import uuid +from pathlib import Path from typing import Annotated, Any, Iterator, Type, cast import aiofiles @@ -131,19 +132,19 @@ async def get_document_upload_info( async def upload_document( token: Annotated[str, Form()], file: UploadFile ) -> schemas.Document: - if not issubclass(config.document, ragna.core.LocalDocument): + if not issubclass(config.document, ragna.core.FilesystemDocument): raise HTTPException( status_code=400, detail="Ragna configuration does not support local upload", ) with get_session() as session: - user, id = ragna.core.LocalDocument.decode_upload_token(token) + user, id = ragna.core.FilesystemDocument.decode_upload_token(token) document, metadata = database.get_document(session, user=user, id=id) - core_document = ragna.core.LocalDocument( + core_document = ragna.core.FilesystemDocument( id=document.id, name=document.name, metadata=metadata ) - core_document.path.parent.mkdir(parents=True, exist_ok=True) + Path(core_document.path).parent.mkdir(parents=True, exist_ok=True) async with aiofiles.open(core_document.path, "wb") as document_file: while content := await file.read(1024): await document_file.write(content) From b6d7292b0e4a7c97d927fdb388a81ee3b223db1f Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Wed, 6 Dec 2023 23:16:31 +0100 Subject: [PATCH 2/5] Add filesystem abstraction to REST API The REST API has been updated to allow reading directly from fsspec-supported filesystems. The new workflow is to first generate metadata on the server via a GET to the '/document' endpoint. If the metadata references a document on a fsspec-supported path, then we go straight to the 'chats' endpoint. Otherwise, the workflow is as before and we must upload the document to the server via a POST to the '/document' endpoint. --- ragna/deploy/_api/core.py | 31 ++++++++++++++++++++++--------- ragna/deploy/_api/database.py | 9 +++++++-- ragna/deploy/_api/orm.py | 1 + ragna/deploy/_api/schemas.py | 4 +++- ragna/deploy/_config.py | 4 ++-- 5 files changed, 35 insertions(+), 14 deletions(-) diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py index 8cdd55c0..196ae151 100644 --- a/ragna/deploy/_api/core.py +++ b/ragna/deploy/_api/core.py @@ -2,7 +2,7 @@ import itertools import uuid from pathlib import Path -from typing import Annotated, Any, Iterator, Type, cast +from typing import Annotated, Any, Iterator, Optional, Type, cast import aiofiles from fastapi import Depends, FastAPI, Form, HTTPException, Request, UploadFile @@ -93,6 +93,7 @@ def _get_component_json_schema( async def get_components(_: UserDependency) -> schemas.Components: return schemas.Components( documents=sorted(config.document.supported_suffixes()), + filesystems=sorted(config.document.supported_filesystems()), source_storages=[ _get_component_json_schema(source_storage) for source_storage in config.components.source_storages @@ -117,11 +118,16 @@ def get_session() -> Iterator[database.Session]: async def get_document_upload_info( user: UserDependency, name: str, + prefixed_path: Optional[str] = None, ) -> schemas.DocumentUploadInfo: with get_session() as session: - document = schemas.Document(name=name) + document = schemas.Document(prefixed_path=prefixed_path, name=name) url, data, metadata = await config.document.get_upload_info( - config=config, user=user, id=document.id, name=document.name + config=config, + user=user, + id=document.id, + name=document.name, + path=prefixed_path, ) database.add_document( session, user=user, document=document, metadata=metadata @@ -154,9 +160,14 @@ async def upload_document( def schema_to_core_chat( session: database.Session, *, user: str, chat: schemas.Chat ) -> ragna.core.Chat: - core_chat = rag.chat( - documents=[ - config.document( + documents = [] + for document in chat.metadata.documents: + if document.prefixed_path: + doc = config.document.from_path( + id=document.id, path=document.prefixed_path + ) + else: + doc = config.document( id=document.id, name=document.name, metadata=database.get_document( @@ -165,8 +176,10 @@ def schema_to_core_chat( id=document.id, )[1], ) - for document in chat.metadata.documents - ], + documents.append(doc) + + core_chat = rag.chat( + documents=documents, source_storage=get_component(chat.metadata.source_storage), # type: ignore[arg-type] assistant=get_component(chat.metadata.assistant), # type: ignore[arg-type] user=user, @@ -193,7 +206,7 @@ async def create_chat( # Although we don't need the actual ragna.core.Chat object here, # we use it to validate the documents and metadata. - schema_to_core_chat(session, user=user, chat=chat) + # schema_to_core_chat(session, user=user, chat=chat) database.add_chat(session, user=user, chat=chat) return chat diff --git a/ragna/deploy/_api/database.py b/ragna/deploy/_api/database.py index 351f51bb..6ba1009f 100644 --- a/ragna/deploy/_api/database.py +++ b/ragna/deploy/_api/database.py @@ -50,13 +50,16 @@ def add_document( user_id=_get_user_id(session, user), name=document.name, metadata_=metadata, + prefixed_path=document.prefixed_path, ) ) session.commit() def _orm_to_schema_document(document: orm.Document) -> schemas.Document: - return schemas.Document(id=document.id, name=document.name) + return schemas.Document( + id=document.id, name=document.name, prefixed_path=document.prefixed_path + ) @functools.lru_cache(maxsize=1024) @@ -100,7 +103,9 @@ def add_chat(session: Session, *, user: str, chat: schemas.Chat) -> None: def _orm_to_schema_chat(chat: orm.Chat) -> schemas.Chat: documents = [ - schemas.Document(id=document.id, name=document.name) + schemas.Document( + id=document.id, name=document.name, prefixed_path=document.prefixed_path + ) for document in chat.documents ] messages = [ diff --git a/ragna/deploy/_api/orm.py b/ragna/deploy/_api/orm.py index 5911a362..8c63db0d 100644 --- a/ragna/deploy/_api/orm.py +++ b/ragna/deploy/_api/orm.py @@ -32,6 +32,7 @@ class Document(Base): id = Column(types.Uuid, primary_key=True) # type: ignore[attr-defined] user_id = Column(ForeignKey("users.id")) name = Column(types.String) + prefixed_path = Column(types.String) # Mind the trailing underscore here. Unfortunately, this is necessary, because # metadata without the underscore is reserved by SQLAlchemy metadata_ = Column(types.JSON) diff --git a/ragna/deploy/_api/schemas.py b/ragna/deploy/_api/schemas.py index ed7a9be3..fec41f44 100644 --- a/ragna/deploy/_api/schemas.py +++ b/ragna/deploy/_api/schemas.py @@ -2,7 +2,7 @@ import datetime import uuid -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, Field @@ -10,6 +10,7 @@ class Components(BaseModel): + filesystems: list[str] documents: list[str] source_storages: list[dict[str, Any]] assistants: list[dict[str, Any]] @@ -18,6 +19,7 @@ class Components(BaseModel): class Document(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) name: str + prefixed_path: Optional[str] = None @classmethod def from_core(cls, document: ragna.core.Document) -> Document: diff --git a/ragna/deploy/_config.py b/ragna/deploy/_config.py index 11bbd1e3..c0a261b5 100644 --- a/ragna/deploy/_config.py +++ b/ragna/deploy/_config.py @@ -11,7 +11,7 @@ SettingsConfigDict, ) -from ragna.core import Assistant, Document, RagnaException, SourceStorage +from ragna.core import Assistant, FilesystemDocument, RagnaException, SourceStorage from ._authentication import Authentication @@ -80,7 +80,7 @@ class Config(ConfigBase): default_factory=lambda: Path.home() / ".cache" / "ragna" ) - document: ImportString[type[Document]] = "ragna.core.LocalDocument" # type: ignore[assignment] + document: ImportString[type[FilesystemDocument]] = "ragna.core.FilesystemDocument" # type: ignore[assignment] authentication: ImportString[ type[Authentication] From dfc7d810f4fa05f744304051627a48fd75d07418 Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Thu, 7 Dec 2023 12:30:10 +0100 Subject: [PATCH 3/5] Add examples for using REST and Py APIs --- examples/python_api/fsspec_demo.ipynb | 110 +++++++++++++++ examples/rest_api/fsspec_demo.ipynb | 192 ++++++++++++++++++++++++++ 2 files changed, 302 insertions(+) create mode 100644 examples/python_api/fsspec_demo.ipynb create mode 100644 examples/rest_api/fsspec_demo.ipynb diff --git a/examples/python_api/fsspec_demo.ipynb b/examples/python_api/fsspec_demo.ipynb new file mode 100644 index 00000000..22fee068 --- /dev/null +++ b/examples/python_api/fsspec_demo.ipynb @@ -0,0 +1,110 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from ragna import Rag, assistants, source_storages\n", + "\n", + "rag = Rag()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from ragna.core import filesystem_glob\n", + "\n", + "globs = filesystem_glob(\"github://nenb/Notes/programming/ADR/**.txt\")\n", + "# ugly, but it works for now\n", + "documents = [\"github://nenb/Notes/\" + glob for glob in globs]\n", + "\n", + "chat = rag.chat(\n", + " documents=documents,\n", + " source_storage=source_storages.LanceDB,\n", + " assistant=assistants.Gpt35Turbo16k,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Message(content='How can I help you with the documents?', role=, sources=[])" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await chat.prepare()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SQS was chosen for task queues due to several factors. One of the deciding factors was the simplicity of the client interface, including the built-in support for Dead Letter Queues. Additionally, SQS can be easily replicated across regions if necessary in the future. Another consideration was the cost-effectiveness of using SQS compared to other options like Kafka, especially when replicated across regions. Overall, SQS provided a simple and reliable solution for managing task queues in the system.\n" + ] + } + ], + "source": [ + "print(await chat.answer(\"Why was SQS chosen for task queues?\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ragna-dev", + "language": "python", + "name": "ragna-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/rest_api/fsspec_demo.ipynb b/examples/rest_api/fsspec_demo.ipynb new file mode 100644 index 00000000..936fc4b6 --- /dev/null +++ b/examples/rest_api/fsspec_demo.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import requests\n", + "\n", + "# smoke-test -> have you started the API?\n", + "requests.get(\"http://localhost:31476/docs\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "t = requests.post(\n", + " \"http://localhost:31476/token\", data={\"username\": \"nenb\", \"password\": \"nenb\"}\n", + ").json()\n", + "\n", + "headers = {\"Authorization\": f\"Bearer {t}\"}" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'filesystems': ['github', 'local'],\n", + " 'documents': ['.pdf', '.txt'],\n", + " 'source_storages': [{'properties': {'chunk_overlap': {'default': 250,\n", + " 'title': 'Chunk Overlap',\n", + " 'type': 'integer'},\n", + " 'chunk_size': {'default': 500, 'title': 'Chunk Size', 'type': 'integer'},\n", + " 'num_tokens': {'default': 1024, 'title': 'Num Tokens', 'type': 'integer'}},\n", + " 'required': [],\n", + " 'title': 'Chroma',\n", + " 'type': 'object'},\n", + " {'properties': {'chunk_overlap': {'default': 250,\n", + " 'title': 'Chunk Overlap',\n", + " 'type': 'integer'},\n", + " 'chunk_size': {'default': 500, 'title': 'Chunk Size', 'type': 'integer'},\n", + " 'num_tokens': {'default': 1024, 'title': 'Num Tokens', 'type': 'integer'}},\n", + " 'required': [],\n", + " 'title': 'LanceDB',\n", + " 'type': 'object'}],\n", + " 'assistants': [{'properties': {'max_new_tokens': {'default': 256,\n", + " 'title': 'Max New Tokens',\n", + " 'type': 'integer'}},\n", + " 'title': 'OpenAI/gpt-4',\n", + " 'type': 'object'},\n", + " {'properties': {'max_new_tokens': {'default': 256,\n", + " 'title': 'Max New Tokens',\n", + " 'type': 'integer'}},\n", + " 'title': 'OpenAI/gpt-3.5-turbo-16k',\n", + " 'type': 'object'}]}" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# there is a new filesystems component\n", + "requests.get(\"http://localhost:31476/components\", headers=headers).json()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "name = \"bm25.pdf\"\n", + "prefixed_path = (\n", + " \"github://papers-we-love/papers-we-love/information_retrieval/okapi-at-trec3.pdf\"\n", + ")\n", + "\n", + "# this creates the relevant metadata on the server and stores in the db\n", + "metadata = requests.get(\n", + " f\"http://localhost:31476/document?name={name}&prefixed_path={prefixed_path}\",\n", + " headers=headers,\n", + ").json()\n", + "docs = metadata[\"document\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "data = json.dumps(\n", + " {\n", + " \"name\": \"BM25\",\n", + " \"source_storage\": \"LanceDB\",\n", + " \"assistant\": \"OpenAI/gpt-3.5-turbo-16k\",\n", + " \"params\": {},\n", + " \"documents\": [docs],\n", + " }\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# create a new chat\n", + "chat_metadata = requests.post(\n", + " \"http://localhost:31476/chats\", headers=headers, data=data\n", + ").json()\n", + "chat_id = chat_metadata[\"id\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# this retrieves the pdf from github, and stores the embeddings on the server\n", + "requests.post(f\"http://localhost:31476/chats/{chat_id}/prepare\", headers=headers).json()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The new developments for TREC-3 included the introduction of the simple inverse collection frequency (ICF) term-weighting scheme, which incorporated within-document frequency, document length, and within-query frequency components. Additionally, there were advancements in automatic ad hoc and routing results, as well as the development of a user interface and search framework. Query expansion and routing term selection were also successful developments. Modified term-weighting functions and passage retrieval had small beneficial effects.\n" + ] + } + ], + "source": [ + "prompt = \"What were the new developments for TREC-3?\"\n", + "response = requests.post(\n", + " f\"http://localhost:31476/chats/{chat_id}/answer?prompt={prompt}\", headers=headers\n", + ").json()\n", + "\n", + "print(response[\"message\"][\"content\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ragna-dev", + "language": "python", + "name": "ragna-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From ae6bde8c29cef3a6cc9857c8973bac3d9c9fc005 Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Sat, 16 Dec 2023 20:28:19 +0100 Subject: [PATCH 4/5] Add filesystem abstraction to UI --- ragna/deploy/_ui/api_wrapper.py | 13 +- ragna/deploy/_ui/components/file_uploader.py | 4 +- ragna/deploy/_ui/imgs/github-mark-white.svg | 1 + ragna/deploy/_ui/imgs/github-mark.svg | 1 + ragna/deploy/_ui/modal_configuration.py | 148 ++++++++++++++++++- ragna/deploy/_ui/styles.py | 2 +- 6 files changed, 158 insertions(+), 11 deletions(-) create mode 100644 ragna/deploy/_ui/imgs/github-mark-white.svg create mode 100644 ragna/deploy/_ui/imgs/github-mark.svg diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py index 647be9b6..11fd2744 100644 --- a/ragna/deploy/_ui/api_wrapper.py +++ b/ragna/deploy/_ui/api_wrapper.py @@ -85,6 +85,18 @@ async def answer(self, chat_id, prompt): async def get_components(self): return (await self.client.get("/components")).raise_for_status().json() + async def get_document(self, name, prefixed_path=None): + if prefixed_path is None: + params = {"name": name} + else: + params = {"name": name, "prefixed_path": prefixed_path} + + return ( + (await self.client.get("/document", params=params)) + .raise_for_status() + .json() + ) + # Upload and related functions def upload_endpoints(self): return { @@ -113,7 +125,6 @@ async def start_and_prepare( self, name, documents, source_storage, assistant, params={} ): chat = await self.start_chat(name, documents, source_storage, assistant, params) - ( await self.client.post(f"/chats/{chat['id']}/prepare", timeout=None) ).raise_for_status() diff --git a/ragna/deploy/_ui/components/file_uploader.py b/ragna/deploy/_ui/components/file_uploader.py index a265d942..45b782a8 100644 --- a/ragna/deploy/_ui/components/file_uploader.py +++ b/ragna/deploy/_ui/components/file_uploader.py @@ -42,7 +42,9 @@ def update_allowed_documents_str(self): @param.depends("uploaded_documents_json", watch=True) async def did_finish_upload(self): if self.after_upload_callback is not None: - await self.after_upload_callback(json.loads(self.uploaded_documents_json)) + await self.after_upload_callback( + uploaded_documents=json.loads(self.uploaded_documents_json), + ) def perform_upload(self, event=None, after_upload_callback=None): self.after_upload_callback = after_upload_callback diff --git a/ragna/deploy/_ui/imgs/github-mark-white.svg b/ragna/deploy/_ui/imgs/github-mark-white.svg new file mode 100644 index 00000000..c679c236 --- /dev/null +++ b/ragna/deploy/_ui/imgs/github-mark-white.svg @@ -0,0 +1 @@ + diff --git a/ragna/deploy/_ui/imgs/github-mark.svg b/ragna/deploy/_ui/imgs/github-mark.svg new file mode 100644 index 00000000..98d74c33 --- /dev/null +++ b/ragna/deploy/_ui/imgs/github-mark.svg @@ -0,0 +1 @@ + diff --git a/ragna/deploy/_ui/modal_configuration.py b/ragna/deploy/_ui/modal_configuration.py index a0c72fe2..50931e2a 100644 --- a/ragna/deploy/_ui/modal_configuration.py +++ b/ragna/deploy/_ui/modal_configuration.py @@ -1,5 +1,9 @@ +import asyncio +import os +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone +import fsspec import panel as pn import param @@ -18,6 +22,7 @@ def get_default_chat_name(timezone_offset=None): class ChatConfig(param.Parameterized): allowed_documents = param.List(default=["TXT"]) + available_filesystems = param.List(default=["local"]) source_storage_name = param.Selector() assistant_name = param.Selector() @@ -70,6 +75,8 @@ def to_params_dict(self): class ModalConfiguration(pn.viewable.Viewer): chat_name = param.String() + github_repo = param.String(default="Quansight/ragna/docs") + github_file_formats = param.String(default="md,mdx") config = param.ClassSelector(class_=ChatConfig, default=None) new_chat_ready_callback = param.Callable() @@ -84,6 +91,8 @@ def __init__(self, api_wrapper, **params): upload_endpoints = self.api_wrapper.upload_endpoints() + self.github_documents = list() + self.chat_name_input = pn.widgets.TextInput.from_param( self.param.chat_name, stylesheets=[ui.BK_INPUT_GRAY_BORDER], @@ -101,7 +110,9 @@ def __init__(self, api_wrapper, **params): self.cancel_button.on_click(self.cancel_button_callback) self.start_chat_button = pn.widgets.Button( - name="Start Conversation", button_type="primary", min_width=375 + name="Start Conversation", + button_type="primary", + min_width=375, ) self.start_chat_button.on_click(self.did_click_on_start_chat_button) @@ -116,27 +127,108 @@ def __init__(self, api_wrapper, **params): self.got_timezone = False - def did_click_on_start_chat_button(self, event): - if not self.document_uploader.can_proceed_to_upload(): + self.add_repo_button = pn.widgets.Button( + name="Add", button_type="default", height=30 + ) + + self.add_repo_button.on_click(self.add_github_repo_metadata) + + async def add_github_repo_metadata(self, event): + # disable call to 'prepare' endpoint until metadata created + self.start_chat_button.disabled = True + self.start_chat_button.name = "Loading..." + + if self.github_repo: + try: + org, repo, *ks = self.github_repo.split("/") + except ValueError: + print("Need to implement error handling...") + self.start_chat_button.disabled = False + self.start_chat_button.name = "Start Conversation" + return + + key = "/".join(ks) + loop = asyncio.get_event_loop() + # necessary to prevent UI from freezing + with ThreadPoolExecutor() as executor: + filepaths = await loop.run_in_executor( + executor, self.get_filepaths, org, repo, key + ) + + for f in filepaths: + metadata = await self.api_wrapper.get_document( + name=os.path.basename(f), + prefixed_path=f"github://{org}/{repo}/{f}", + ) + self.github_documents.append(metadata["document"]) + + self.start_chat_button.disabled = False + self.start_chat_button.name = "Start Conversation" + + def get_filepaths(self, org, repo, key): + fs = fsspec.filesystem( + "github", + org=org, + repo=repo, + username=os.environ["GITHUB_USERNAME"], + token=os.environ["GITHUB_TOKEN"], + ) + + if fs.isfile(key): + filepaths = [key] + elif fs.isdir(key): + # Expected format is comma separated list of file extensions + file_format_string = self.github_file_formats.replace(" ", "") + if key: + pattern = f"{key}/**" + glob = fs.glob(pattern) + else: # root of repo + pattern = "**" + glob = fs.glob(pattern) + filepaths = [ + g for g in glob if g.split(".")[-1] in file_format_string.split(",") + ] + else: + print("Need to implement error handling...") + # TODO: remove this return statement after error handling is implemented + filepaths = [] + + return filepaths + + async def did_click_on_start_chat_button(self, event): + if ( + not self.document_uploader.can_proceed_to_upload() + and not self.github_documents + ): self.change_upload_files_label("missing_file") else: self.start_chat_button.disabled = True - self.document_uploader.perform_upload(event, self.did_finish_upload) + self.add_repo_button.disabled = True + self.document_uploader.perform_upload( + event, + self.did_finish_upload, + ) async def did_finish_upload(self, uploaded_documents): # at this point, the UI has uploaded the files to the API. # We can now start the chat + if self.github_documents is not None: + documents = uploaded_documents + self.github_documents + else: + documents = uploaded_documents + try: new_chat_id = await self.api_wrapper.start_and_prepare( name=self.chat_name, - documents=uploaded_documents, + documents=documents, source_storage=self.config.source_storage_name, assistant=self.config.assistant_name, params=self.config.to_params_dict(), ) self.start_chat_button.disabled = False + self.add_repo_button.disabled = False if self.new_chat_ready_callback is not None: await self.new_chat_ready_callback(new_chat_id) @@ -145,16 +237,17 @@ async def did_finish_upload(self, uploaded_documents): self.change_upload_files_label("upload_error") self.document_uploader.loading = False self.start_chat_button.disabled = False + self.add_repo_button.disabled = False def change_upload_files_label(self, mode="normal"): if mode == "upload_error": - self.upload_files_label.object = "Upload files (required)An error occured. Please try again or contact your administrator." + self.upload_files_label.object = "Upload files An error occured. Please try again or contact your administrator." elif mode == "missing_file": self.upload_files_label.object = ( - "Upload files (required)" + "Upload files" ) else: - self.upload_files_label.object = "Upload files (required)" + self.upload_files_label.object = "Upload files" async def model_section(self): # prevents re-rendering the section @@ -167,6 +260,7 @@ async def model_section(self): config.allowed_documents = [ ext[1:].upper() for ext in components["documents"] ] + config.available_filesystems = [fs for fs in components["filesystems"]] assistants = [component["title"] for component in components["assistants"]] @@ -317,6 +411,43 @@ def toggle_card(event): return pn.Column(toggle_button, card) + @pn.depends("config", "config.available_filesystems") + async def github_section(self): + if self.config is None: + return + + if "github" not in self.config.available_filesystems: + return + else: + return pn.Row( + pn.Column( + pn.pane.HTML( + " GitHub repository", + height=30, + ), + pn.widgets.TextInput.from_param( + self.param.github_repo, + name="", + stylesheets=[ui.BK_INPUT_GRAY_BORDER], + ), + ), + pn.Column( + pn.pane.HTML( + " File formats", + height=30, + ), + pn.widgets.TextInput.from_param( + self.param.github_file_formats, + name="", + stylesheets=[ui.BK_INPUT_GRAY_BORDER], + ), + ), + pn.Column( + pn.pane.HTML("", height=30), + self.add_repo_button, + ), + ) + def __panel__(self): return pn.Column( pn.pane.HTML( @@ -333,6 +464,7 @@ def __panel__(self): ui.divider(), self.advanced_config_ui, ui.divider(), + self.github_section, self.upload_files_label, self.upload_row, pn.Row(self.cancel_button, self.start_chat_button), diff --git a/ragna/deploy/_ui/styles.py b/ragna/deploy/_ui/styles.py index a3a73f7a..04a6aac9 100644 --- a/ragna/deploy/_ui/styles.py +++ b/ragna/deploy/_ui/styles.py @@ -21,7 +21,7 @@ def divider(): # TABS_SIDEBAR_WIDTH = "20em" # set modal height -CONFIG_MODAL_MIN_HEIGHT = 610 +CONFIG_MODAL_MIN_HEIGHT = 750 CONFIG_MODAL_MAX_HEIGHT = 850 CONFIG_MODAL_WIDTH = 800 From d4a26eeca4474dfc38d2e2bed331522405f9ce43 Mon Sep 17 00:00:00 2001 From: Nick Byrne Date: Sat, 16 Dec 2023 20:39:42 +0100 Subject: [PATCH 5/5] Add hack for parsing Markdown for UI PoC --- ragna/core/_document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 9a6f7288..18b0b04c 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -390,7 +390,7 @@ class TxtDocumentHandler(DocumentHandler): @classmethod def supported_suffixes(cls) -> list[str]: - return [".txt"] + return [".txt", ".md"] def extract_pages(self, document: Document) -> Iterator[Page]: yield Page(text=document.read().decode())