diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 4868b270..0cbae6d2 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -2,6 +2,7 @@ import abc import io +import mimetypes import uuid from functools import cached_property from pathlib import Path @@ -25,11 +26,15 @@ def __init__( name: str, metadata: dict[str, Any], handler: Optional[DocumentHandler] = None, + mime_type: str | None = None, ): self.id = id or uuid.uuid4() self.name = name self.metadata = metadata self.handler = handler or self.get_handler(name) + self.mime_type = ( + mime_type or mimetypes.guess_type(name)[0] or "application/octet-stream" + ) @staticmethod def supported_suffixes() -> set[str]: @@ -76,8 +81,11 @@ def __init__( name: str, metadata: dict[str, Any], handler: Optional[DocumentHandler] = None, + mime_type: str | None = None, ): - super().__init__(id=id, name=name, metadata=metadata, handler=handler) + super().__init__( + id=id, name=name, metadata=metadata, handler=handler, mime_type=mime_type + ) if "path" not in self.metadata: metadata["path"] = str(ragna.local_root() / "documents" / str(self.id)) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index ff320dae..ae4f2004 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -1,3 +1,4 @@ +import io import uuid from typing import Annotated, Any, AsyncIterator @@ -40,6 +41,28 @@ async def content_stream() -> AsyncIterator[bytes]: ], ) + @router.get("/documents") + async def get_documents(user: UserDependency) -> list[schemas.Document]: + return engine.get_documents(user=user.name) + + @router.get("/documents/{id}") + async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document: + return engine.get_document(user=user.name, id=id) + + @router.get("/documents/{id}/content") + async def get_document_content( + user: UserDependency, id: uuid.UUID + ) -> StreamingResponse: + schema_document = engine.get_document(user=user.name, id=id) + core_document = engine._to_core.document(schema_document) + headers = {"Content-Disposition": f"inline; filename={schema_document.name}"} + + return StreamingResponse( + io.BytesIO(core_document.read()), + media_type=core_document.mime_type, + headers=headers, + ) + @router.get("/components") def get_components() -> schemas.Components: return engine.get_components() diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index e2390104..945762f3 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -132,16 +132,16 @@ def add_documents( session.commit() def _get_orm_documents( - self, session: Session, *, user: str, ids: Collection[uuid.UUID] + self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[orm.Document]: # FIXME also check if the user is allowed to access the documents # FIXME: maybe just take the user id to avoid getting it twice in add_chat? - documents = ( - session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) - .scalars() - .all() - ) - if len(documents) != len(ids): + expr = select(orm.Document) + if ids is not None: + expr = expr.where(orm.Document.id.in_(ids)) + documents = session.execute(expr).scalars().all() + + if (ids is not None) and (len(documents) != len(ids)): raise RagnaException( str(set(ids) - {document.id for document in documents}) ) @@ -149,7 +149,7 @@ def _get_orm_documents( return documents # type: ignore[no-any-return] def get_documents( - self, session: Session, *, user: str, ids: Collection[uuid.UUID] + self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[schemas.Document]: return [ self._to_schema.document(document) @@ -288,6 +288,7 @@ def document( user_id=user_id, name=document.name, metadata_=document.metadata, + mime_type=document.mime_type, ) def source(self, source: schemas.Source) -> orm.Source: @@ -354,7 +355,10 @@ def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey: def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( - id=document.id, name=document.name, metadata=document.metadata_ + id=document.id, + name=document.name, + metadata=document.metadata_, + mime_type=document.mime_type, ) def source(self, source: orm.Source) -> schemas.Source: diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6694d32c..bc0b1afe 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,6 +1,6 @@ import secrets import uuid -from typing import Any, AsyncIterator, Optional, cast +from typing import Any, AsyncIterator, Collection, Optional, cast from fastapi import status as http_status_code @@ -156,7 +156,9 @@ def register_documents( # We create core.Document's first, because they might update the metadata core_documents = [ self._config.document( - name=registration.name, metadata=registration.metadata + name=registration.name, + metadata=registration.metadata, + mime_type=registration.mime_type, ) for registration in document_registrations ] @@ -182,10 +184,7 @@ async def store_documents( streams = dict(ids_and_streams) - with self._database.get_session() as session: - documents = self._database.get_documents( - session, user=user, ids=streams.keys() - ) + documents = self.get_documents(user=user, ids=streams.keys()) for document in documents: core_document = cast( @@ -193,6 +192,15 @@ async def store_documents( ) await core_document._write(streams[document.id]) + def get_documents( + self, *, user: str, ids: Collection[uuid.UUID] | None = None + ) -> list[schemas.Document]: + with self._database.get_session() as session: + return self._database.get_documents(session, user=user, ids=ids) + + def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document: + return self.get_documents(user=user, ids=[id])[0] + def create_chat( self, *, user: str, chat_creation: schemas.ChatCreation ) -> schemas.Chat: @@ -280,6 +288,7 @@ def document(self, document: schemas.Document) -> core.Document: id=document.id, name=document.name, metadata=document.metadata, + mime_type=document.mime_type, ) def source(self, source: schemas.Source) -> core.Source: @@ -328,6 +337,7 @@ def document(self, document: core.Document) -> schemas.Document: id=document.id, name=document.name, metadata=document.metadata, + mime_type=document.mime_type, ) def source(self, source: core.Source) -> schemas.Source: diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index a5660db4..e1d8921f 100644 --- a/ragna/deploy/_orm.py +++ b/ragna/deploy/_orm.py @@ -103,6 +103,7 @@ class Document(Base): # Mind the trailing underscore here. Unfortunately, this is necessary, because # metadata without the underscore is reserved by SQLAlchemy metadata_ = Column(Json, nullable=False) + mime_type = Column(types.String, nullable=False) chats = relationship( "Chat", secondary=document_chat_association_table, diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 6bbfea63..2d9c50bd 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -78,12 +78,14 @@ class Components(BaseModel): class DocumentRegistration(BaseModel): name: str metadata: dict[str, Any] = Field(default_factory=dict) + mime_type: str | None = None class Document(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) name: str metadata: dict[str, Any] + mime_type: str class Source(BaseModel): diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 973c2c49..684a1a22 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -4,6 +4,7 @@ from ragna import assistants from ragna.core import RagnaException from ragna.deploy import Config +from tests.deploy.api.utils import upload_documents from tests.deploy.utils import make_api_app, make_api_client @@ -56,17 +57,8 @@ def test_unknown_component(tmp_local_root): with open(document_path, "w") as file: file.write("!\n") - with make_api_client( - config=Config(), ignore_unavailable_components=False - ) as client: - document = ( - client.post("/api/documents", json=[{"name": document_path.name}]) - .raise_for_status() - .json()[0] - ) - - with open(document_path, "rb") as file: - client.put("/api/documents", files={"documents": (document["id"], file)}) + with make_api_client(config=config, ignore_unavailable_components=False) as client: + document = upload_documents(client=client, document_paths=[document_path])[0] response = client.post( "/api/chats", @@ -80,7 +72,7 @@ def test_unknown_component(tmp_local_root): }, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - error = response.json()["error"] - assert "Unknown component" in error["message"] + error = response.json()["error"] + assert "Unknown component" in error["message"] diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py new file mode 100644 index 00000000..5f40e3e8 --- /dev/null +++ b/tests/deploy/api/test_endpoints.py @@ -0,0 +1,109 @@ +import mimetypes + +import pytest + +from ragna.deploy import Config +from tests.deploy.api.utils import upload_documents +from tests.deploy.utils import make_api_client + +_document_content_text = [ + f"Needs more {needs_more_of}\n" for needs_more_of in ["reverb", "cowbell"] +] + + +mime_types = pytest.mark.parametrize( + ("mime_type",), + [ + (None,), # Let the mimetypes library decide + ("text/markdown",), + ("application/pdf",), + ], +) + + +@mime_types +def test_get_documents(tmp_local_root, mime_type): + config = Config(local_root=tmp_local_root) + + document_root = config.local_root / "documents" + document_root.mkdir() + document_paths = [ + document_root / f"test{idx}.txt" for idx in range(len(_document_content_text)) + ] + for content, document_path in zip(_document_content_text, document_paths): + with open(document_path, "w") as file: + file.write(content) + + with make_api_client(config=config, ignore_unavailable_components=False) as client: + documents = upload_documents( + client=client, + document_paths=document_paths, + mime_types=[mime_type for _ in document_paths], + ) + response = client.get("/api/documents").raise_for_status() + + # Sort the items in case they are retrieved in different orders + def sorting_key(d): + return d["id"] + + assert sorted(documents, key=sorting_key) == sorted( + response.json(), key=sorting_key + ) + + +@mime_types +def test_get_document(tmp_local_root, mime_type): + config = Config(local_root=tmp_local_root) + + document_root = config.local_root / "documents" + document_root.mkdir() + document_path = document_root / "test.txt" + with open(document_path, "w") as file: + file.write(_document_content_text[0]) + + with make_api_client(config=config, ignore_unavailable_components=False) as client: + document = upload_documents( + client=client, + document_paths=[document_path], + mime_types=[mime_type], + )[0] + response = client.get(f"/api/documents/{document['id']}").raise_for_status() + + assert document == response.json() + + +@mime_types +def test_get_document_content(tmp_local_root, mime_type): + config = Config(local_root=tmp_local_root) + + document_root = config.local_root / "documents" + document_root.mkdir() + document_path = document_root / "test.txt" + document_content = _document_content_text[0] + with open(document_path, "w") as file: + file.write(document_content) + + with make_api_client(config=config, ignore_unavailable_components=False) as client: + document = upload_documents( + client=client, + document_paths=[document_path], + mime_types=[mime_type], + )[0] + + with client.stream( + "GET", f"/api/documents/{document['id']}/content" + ) as response: + response_mime_type = response.headers["content-type"].split(";")[0] + received_lines = list(response.iter_lines()) + + assert received_lines == [document_content.replace("\n", "")] + + assert ( + document["mime_type"] + == response_mime_type + == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] + ) + ) diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py new file mode 100644 index 00000000..04352951 --- /dev/null +++ b/tests/deploy/api/utils.py @@ -0,0 +1,37 @@ +import contextlib + + +def upload_documents(*, client, document_paths, mime_types=None): + if mime_types is None: + mime_types = [None for _ in document_paths] + else: + assert len(mime_types) == len(document_paths) + documents = ( + client.post( + "/api/documents", + json=[ + { + "name": document_path.name, + "mime_type": mime_type, + } + for document_path, mime_type in zip(document_paths, mime_types) + ], + ) + .raise_for_status() + .json() + ) + + with contextlib.ExitStack() as stack: + files = [ + stack.enter_context(open(document_path, "rb")) + for document_path in document_paths + ] + client.put( + "/api/documents", + files=[ + ("documents", (document["id"], file)) + for document, file in zip(documents, files) + ], + ) + + return documents