From a9ebd9ca002b693e88b691d77c6fa00c9d6a7931 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sun, 26 Jan 2025 20:55:30 -0800 Subject: [PATCH 01/49] Give option for database to return all documents --- ragna/deploy/_database.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index e2390104..0830f8a2 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -132,16 +132,20 @@ def add_documents( session.commit() def _get_orm_documents( - self, session: Session, *, user: str, ids: Collection[uuid.UUID] + self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[orm.Document]: # FIXME also check if the user is allowed to access the documents # FIXME: maybe just take the user id to avoid getting it twice in add_chat? documents = ( - session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) - .scalars() - .all() + ( + session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) + .scalars() + .all() + ) + if ids is not None + else session.execute(select(orm.Document)).scalars().all() ) - if len(documents) != len(ids): + if (ids is not None) and (len(documents) != len(ids)): raise RagnaException( str(set(ids) - {document.id for document in documents}) ) @@ -149,7 +153,7 @@ def _get_orm_documents( return documents # type: ignore[no-any-return] def get_documents( - self, session: Session, *, user: str, ids: Collection[uuid.UUID] + self, session: Session, *, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[schemas.Document]: return [ self._to_schema.document(document) From ce37b0cd9e87a33e233b96ad554ef66b7d493eb6 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sun, 26 Jan 2025 21:36:51 -0800 Subject: [PATCH 02/49] Add `get_documents` --- ragna/deploy/_engine.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6694d32c..e1185099 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -1,6 +1,6 @@ import secrets import uuid -from typing import Any, AsyncIterator, Optional, cast +from typing import Any, AsyncIterator, Collection, Optional, cast from fastapi import status as http_status_code @@ -182,10 +182,7 @@ async def store_documents( streams = dict(ids_and_streams) - with self._database.get_session() as session: - documents = self._database.get_documents( - session, user=user, ids=streams.keys() - ) + documents = self.get_documents(user, streams.keys()) for document in documents: core_document = cast( @@ -193,6 +190,12 @@ async def store_documents( ) await core_document._write(streams[document.id]) + def get_documents(self, user: str, ids: Collection[uuid.UUID] | None = None): + with self._database.get_session() as session: + documents = self._database.get_documents(session, user=user, ids=ids) + + return documents + def create_chat( self, *, user: str, chat_creation: schemas.ChatCreation ) -> schemas.Chat: From b970b96dfc7e77c77284131dc13d8f8f57ae68ae Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:13:00 -0800 Subject: [PATCH 03/49] Add `GET` endpoint for `/documents` --- ragna/deploy/_api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index ff320dae..760357d2 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -40,6 +40,10 @@ async def content_stream() -> AsyncIterator[bytes]: ], ) + @router.get("/documents") + async def get_documents(user: UserDependency) -> list[schemas.Document]: + return engine.get_documents(user.name) + @router.get("/components") def get_components() -> schemas.Components: return engine.get_components() From f33e8ea511bad04c46cc4d7853fbd62c332477b4 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:27:52 -0800 Subject: [PATCH 04/49] Add `GET` endpoint for a specific document `/documents/{id}` --- ragna/deploy/_api.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 760357d2..661bac20 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -44,6 +44,10 @@ async def content_stream() -> AsyncIterator[bytes]: async def get_documents(user: UserDependency) -> list[schemas.Document]: return engine.get_documents(user.name) + @router.get("/documents/{id}") + async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document: + return next(iter(engine.get_documents(user.name, [id]))) + @router.get("/components") def get_components() -> schemas.Components: return engine.get_components() From c68f874c55594c64f3d581632b7407adeb798168 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:47:11 -0800 Subject: [PATCH 05/49] Fix mypy error --- ragna/deploy/_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index e1185099..2a1a0403 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -190,7 +190,9 @@ async def store_documents( ) await core_document._write(streams[document.id]) - def get_documents(self, user: str, ids: Collection[uuid.UUID] | None = None): + def get_documents( + self, user: str, ids: Collection[uuid.UUID] | None = None + ) -> list[schemas.Document]: with self._database.get_session() as session: documents = self._database.get_documents(session, user=user, ids=ids) From e7bf467302644395bbdf5382460d10595d428030 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 19:25:42 -0800 Subject: [PATCH 06/49] Add `get_document` to engine for convenience --- ragna/deploy/_api.py | 2 +- ragna/deploy/_engine.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 661bac20..73c1b91d 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -46,7 +46,7 @@ async def get_documents(user: UserDependency) -> list[schemas.Document]: @router.get("/documents/{id}") async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document: - return next(iter(engine.get_documents(user.name, [id]))) + return engine.get_document(user.name, id) @router.get("/components") def get_components() -> schemas.Components: diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 2a1a0403..6607125d 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -198,6 +198,9 @@ def get_documents( return documents + def get_document(self, user: str, id: uuid.UUID) -> schemas.Document: + return next(iter(self.get_documents(user, [id]))) + def create_chat( self, *, user: str, chat_creation: schemas.ChatCreation ) -> schemas.Chat: From ce6a4f20b310b603e3c6fbfd054817974168aefb Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 19:34:50 -0800 Subject: [PATCH 07/49] Clean up --- ragna/deploy/_database.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 0830f8a2..60cc8f10 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -137,13 +137,13 @@ def _get_orm_documents( # FIXME also check if the user is allowed to access the documents # FIXME: maybe just take the user id to avoid getting it twice in add_chat? documents = ( - ( - session.execute(select(orm.Document).where(orm.Document.id.in_(ids))) - .scalars() - .all() + session.execute( + select(orm.Document).where(orm.Document.id.in_(ids)) + if ids is not None + else select(orm.Document) ) - if ids is not None - else session.execute(select(orm.Document)).scalars().all() + .scalars() + .all() ) if (ids is not None) and (len(documents) != len(ids)): raise RagnaException( From c4bbcaf615c3fc176258a057a53e7c27000476ae Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:31:22 -0800 Subject: [PATCH 08/49] Add support for MIME types in `core.Document`s --- ragna/core/_document.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 4868b270..a2b8902b 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -14,6 +14,14 @@ from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin +_MIME_TYPES = { + ".pdf": "application/pdf", + ".txt": "text/plain", + ".md": "text/plain", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", +} + class Document(RequirementsMixin, abc.ABC): """Abstract base class for all documents.""" @@ -25,11 +33,13 @@ def __init__( name: str, metadata: dict[str, Any], handler: Optional[DocumentHandler] = None, + mime_type: str | None = None, ): self.id = id or uuid.uuid4() self.name = name self.metadata = metadata self.handler = handler or self.get_handler(name) + self.mime_type = mime_type or self.parse_mime_type(name) @staticmethod def supported_suffixes() -> set[str]: @@ -59,6 +69,16 @@ def read(self) -> bytes: ... def extract_pages(self) -> Iterator[Page]: yield from self.handler.extract_pages(self) + @staticmethod + def parse_mime_type(name: str) -> str: + """Parse file MIME-type from file name suffix. + + Args: + name: Name of the document. + """ + + return _MIME_TYPES.get(Path(name).suffix, "application/octet-stream") + class LocalDocument(Document): """Document class for files on the local file system. From 1ab6f581c1915abcd142719200f85855fe4c0b9b Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:34:47 -0800 Subject: [PATCH 09/49] Add `GET` `/documents/{id}/content` endpoint --- ragna/deploy/_api.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 73c1b91d..73316611 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -1,3 +1,4 @@ +import io import uuid from typing import Annotated, Any, AsyncIterator @@ -48,6 +49,20 @@ async def get_documents(user: UserDependency) -> list[schemas.Document]: async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document: return engine.get_document(user.name, id) + @router.get("/documents/{id}/content") + async def get_document_content( + user: UserDependency, id: uuid.UUID + ) -> StreamingResponse: + schema_document = engine.get_document(user.name, id) + core_document = engine._to_core(schema_document) + headers = {"Content-Disposition": f"inline; filename={schema_document.name}"} + + return StreamingResponse( + io.BytesIO(core_document.read()), + media_type=core_document.mime_type, + headers=headers, + ) + @router.get("/components") def get_components() -> schemas.Components: return engine.get_components() From e57950c1946adac831114e620fd06b9bd290f7f5 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:02:07 -0800 Subject: [PATCH 10/49] Call correct method --- ragna/deploy/_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 73316611..07748f7e 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -54,7 +54,7 @@ async def get_document_content( user: UserDependency, id: uuid.UUID ) -> StreamingResponse: schema_document = engine.get_document(user.name, id) - core_document = engine._to_core(schema_document) + core_document = engine._to_core.document(schema_document) headers = {"Content-Disposition": f"inline; filename={schema_document.name}"} return StreamingResponse( From 1bc57b03fddca7d36a8b010c7e32d80a0e9ea1d4 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:20:22 -0800 Subject: [PATCH 11/49] Use the builtin `mimetypes` library instead of custom logic --- ragna/core/_document.py | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index a2b8902b..b2810173 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -2,6 +2,7 @@ import abc import io +import mimetypes import uuid from functools import cached_property from pathlib import Path @@ -14,14 +15,6 @@ from ._utils import PackageRequirement, RagnaException, Requirement, RequirementsMixin -_MIME_TYPES = { - ".pdf": "application/pdf", - ".txt": "text/plain", - ".md": "text/plain", - ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", -} - class Document(RequirementsMixin, abc.ABC): """Abstract base class for all documents.""" @@ -39,7 +32,11 @@ def __init__( self.name = name self.metadata = metadata self.handler = handler or self.get_handler(name) - self.mime_type = mime_type or self.parse_mime_type(name) + self.mime_type = ( + mime_type + or next(iter(mimetypes.guess_type(Path(name)))) + or "application/octet-stream" + ) @staticmethod def supported_suffixes() -> set[str]: @@ -69,16 +66,6 @@ def read(self) -> bytes: ... def extract_pages(self) -> Iterator[Page]: yield from self.handler.extract_pages(self) - @staticmethod - def parse_mime_type(name: str) -> str: - """Parse file MIME-type from file name suffix. - - Args: - name: Name of the document. - """ - - return _MIME_TYPES.get(Path(name).suffix, "application/octet-stream") - class LocalDocument(Document): """Document class for files on the local file system. From f67997e020b50075d6816f152d75f6aab5c818bb Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:53:18 -0800 Subject: [PATCH 12/49] Add mime_type to `Document` schema --- ragna/deploy/_schemas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 6bbfea63..2c1072fe 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -84,6 +84,7 @@ class Document(BaseModel): id: uuid.UUID = Field(default_factory=uuid.uuid4) name: str metadata: dict[str, Any] + mime_type: str class Source(BaseModel): From 23898f005699b3826410f94f89466c40c638f097 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:58:18 -0800 Subject: [PATCH 13/49] Add MIME type to `Document` ORM object --- ragna/deploy/_orm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ragna/deploy/_orm.py b/ragna/deploy/_orm.py index a5660db4..e1d8921f 100644 --- a/ragna/deploy/_orm.py +++ b/ragna/deploy/_orm.py @@ -103,6 +103,7 @@ class Document(Base): # Mind the trailing underscore here. Unfortunately, this is necessary, because # metadata without the underscore is reserved by SQLAlchemy metadata_ = Column(Json, nullable=False) + mime_type = Column(types.String, nullable=False) chats = relationship( "Chat", secondary=document_chat_association_table, From e53b18088b865174a38d8c3762272a762ad2573e Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 01:59:53 -0800 Subject: [PATCH 14/49] Add MIME type to ORM <> Schema converters and Core <> Schema converters --- ragna/deploy/_database.py | 6 +++++- ragna/deploy/_engine.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 60cc8f10..46484b7a 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -292,6 +292,7 @@ def document( user_id=user_id, name=document.name, metadata_=document.metadata, + mime_type=document.mime_type, ) def source(self, source: schemas.Source) -> orm.Source: @@ -358,7 +359,10 @@ def api_key(self, api_key: orm.ApiKey) -> schemas.ApiKey: def document(self, document: orm.Document) -> schemas.Document: return schemas.Document( - id=document.id, name=document.name, metadata=document.metadata_ + id=document.id, + name=document.name, + metadata=document.metadata_, + mime_type=document.mime_type, ) def source(self, source: orm.Source) -> schemas.Source: diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 6607125d..67fdebe7 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -288,6 +288,7 @@ def document(self, document: schemas.Document) -> core.Document: id=document.id, name=document.name, metadata=document.metadata, + mime_type=document.mime_type, ) def source(self, source: schemas.Source) -> core.Source: @@ -336,6 +337,7 @@ def document(self, document: core.Document) -> schemas.Document: id=document.id, name=document.name, metadata=document.metadata, + mime_type=document.mime_type, ) def source(self, source: core.Source) -> schemas.Source: From 2cfa02a41769780f7289638a3fbb034e6eac1dbb Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:14:37 -0800 Subject: [PATCH 15/49] Add `mime_type` to initializer for `LocalDocument` --- ragna/core/_document.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index b2810173..5e4a787f 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -83,8 +83,11 @@ def __init__( name: str, metadata: dict[str, Any], handler: Optional[DocumentHandler] = None, + mime_type: str | None = None, ): - super().__init__(id=id, name=name, metadata=metadata, handler=handler) + super().__init__( + id=id, name=name, metadata=metadata, handler=handler, mime_type=mime_type + ) if "path" not in self.metadata: metadata["path"] = str(ragna.local_root() / "documents" / str(self.id)) From 32dc1ab738da7c3ee0a442375876fb27e1c08b52 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:03:13 -0800 Subject: [PATCH 16/49] Remove unnecessary type conversion --- ragna/core/_document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 5e4a787f..904f031d 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -34,7 +34,7 @@ def __init__( self.handler = handler or self.get_handler(name) self.mime_type = ( mime_type - or next(iter(mimetypes.guess_type(Path(name)))) + or next(iter(mimetypes.guess_type(name))) or "application/octet-stream" ) From 32e11ba32b7073418b645e7e305cafcc9147307e Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:05:38 -0800 Subject: [PATCH 17/49] Make code more concise --- ragna/deploy/_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 67fdebe7..54002702 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -194,9 +194,7 @@ def get_documents( self, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[schemas.Document]: with self._database.get_session() as session: - documents = self._database.get_documents(session, user=user, ids=ids) - - return documents + return self._database.get_documents(session, user=user, ids=ids) def get_document(self, user: str, id: uuid.UUID) -> schemas.Document: return next(iter(self.get_documents(user, [id]))) From 162a3ff53ef0f1616bd465ec43fbbe3e6e0bbc19 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:13:01 -0800 Subject: [PATCH 18/49] Enforce keyword arguments --- ragna/deploy/_api.py | 6 +++--- ragna/deploy/_engine.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ragna/deploy/_api.py b/ragna/deploy/_api.py index 07748f7e..ae4f2004 100644 --- a/ragna/deploy/_api.py +++ b/ragna/deploy/_api.py @@ -43,17 +43,17 @@ async def content_stream() -> AsyncIterator[bytes]: @router.get("/documents") async def get_documents(user: UserDependency) -> list[schemas.Document]: - return engine.get_documents(user.name) + return engine.get_documents(user=user.name) @router.get("/documents/{id}") async def get_document(user: UserDependency, id: uuid.UUID) -> schemas.Document: - return engine.get_document(user.name, id) + return engine.get_document(user=user.name, id=id) @router.get("/documents/{id}/content") async def get_document_content( user: UserDependency, id: uuid.UUID ) -> StreamingResponse: - schema_document = engine.get_document(user.name, id) + schema_document = engine.get_document(user=user.name, id=id) core_document = engine._to_core.document(schema_document) headers = {"Content-Disposition": f"inline; filename={schema_document.name}"} diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 54002702..0b97153d 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -182,7 +182,7 @@ async def store_documents( streams = dict(ids_and_streams) - documents = self.get_documents(user, streams.keys()) + documents = self.get_documents(user=user, ids=streams.keys()) for document in documents: core_document = cast( @@ -191,13 +191,13 @@ async def store_documents( await core_document._write(streams[document.id]) def get_documents( - self, user: str, ids: Collection[uuid.UUID] | None = None + self, *, user: str, ids: Collection[uuid.UUID] | None = None ) -> list[schemas.Document]: with self._database.get_session() as session: return self._database.get_documents(session, user=user, ids=ids) - def get_document(self, user: str, id: uuid.UUID) -> schemas.Document: - return next(iter(self.get_documents(user, [id]))) + def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document: + return next(iter(self.get_documents(user=user, ids=[id]))) def create_chat( self, *, user: str, chat_creation: schemas.ChatCreation From e122713e246fda3744573068eb5286f558deae2b Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 28 Jan 2025 18:31:12 -0800 Subject: [PATCH 19/49] Help expression scale --- ragna/deploy/_database.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 46484b7a..3c9b22ce 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -136,15 +136,11 @@ def _get_orm_documents( ) -> list[orm.Document]: # FIXME also check if the user is allowed to access the documents # FIXME: maybe just take the user id to avoid getting it twice in add_chat? - documents = ( - session.execute( - select(orm.Document).where(orm.Document.id.in_(ids)) - if ids is not None - else select(orm.Document) - ) - .scalars() - .all() - ) + expr = select(orm.Document) + expr = expr if ids is None else expr.where(orm.Document.id.in_(ids)) + + documents = session.execute(expr).scalars().all() + if (ids is not None) and (len(documents) != len(ids)): raise RagnaException( str(set(ids) - {document.id for document in documents}) From 06b5ab7881b47aafa99b1488eed4b5eefddf6e4e Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Wed, 29 Jan 2025 00:36:32 -0800 Subject: [PATCH 20/49] Use `__getitem__` instead of `next(iter(...))` --- ragna/core/_document.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ragna/core/_document.py b/ragna/core/_document.py index 904f031d..0cbae6d2 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -33,9 +33,7 @@ def __init__( self.metadata = metadata self.handler = handler or self.get_handler(name) self.mime_type = ( - mime_type - or next(iter(mimetypes.guess_type(name))) - or "application/octet-stream" + mime_type or mimetypes.guess_type(name)[0] or "application/octet-stream" ) @staticmethod From a9bf0fdbe268cf3de5673458dda85d3e07117c5f Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Wed, 29 Jan 2025 00:37:20 -0800 Subject: [PATCH 21/49] Use traditional `if` statement instead of ternary operator --- ragna/deploy/_database.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ragna/deploy/_database.py b/ragna/deploy/_database.py index 3c9b22ce..945762f3 100644 --- a/ragna/deploy/_database.py +++ b/ragna/deploy/_database.py @@ -137,8 +137,8 @@ def _get_orm_documents( # FIXME also check if the user is allowed to access the documents # FIXME: maybe just take the user id to avoid getting it twice in add_chat? expr = select(orm.Document) - expr = expr if ids is None else expr.where(orm.Document.id.in_(ids)) - + if ids is not None: + expr = expr.where(orm.Document.id.in_(ids)) documents = session.execute(expr).scalars().all() if (ids is not None) and (len(documents) != len(ids)): From 94ca3f9c5393ae97bda7e78c5a28730b8612c013 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Wed, 29 Jan 2025 21:32:29 -0800 Subject: [PATCH 22/49] Add empty `test_endpoints.py` file --- tests/deploy/api/test_endpoints.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/deploy/api/test_endpoints.py diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py new file mode 100644 index 00000000..e69de29b From fffd104eff72baf73a30ab324b8189ae0476df61 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Wed, 29 Jan 2025 21:43:55 -0800 Subject: [PATCH 23/49] Prevent naming collisions --- tests/deploy/api/test_components.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 973c2c49..79e32092 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -51,8 +51,8 @@ def test_unknown_component(tmp_local_root): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" - document_root.mkdir() - document_path = document_root / "test.txt" + document_root.mkdir(exist_ok=True) + document_path = document_root / "test_unknown_component.txt" with open(document_path, "w") as file: file.write("!\n") From e1990542649bcce956fb77a1cbe2391cdf59342c Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 30 Jan 2025 20:11:15 -0800 Subject: [PATCH 24/49] Add test for `GET documents` endpoint --- tests/deploy/api/test_endpoints.py | 58 ++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index e69de29b..b036a47c 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -0,0 +1,58 @@ +import contextlib + +from ragna.deploy import Config +from tests.deploy.utils import make_api_client + + +def test_get_documents(tmp_local_root): + config = Config(local_root=tmp_local_root) + + needs_more_of = ["reverb", "cowbell"] + + document_root = config.local_root / "documents" + document_root.mkdir(exist_ok=True) + document_paths = [ + document_root / f"test_get_documents_{what_it_needs}.txt" + for what_it_needs in needs_more_of + ] + for what_it_needs, document_path in zip(needs_more_of, document_paths): + with open(document_path, "w") as file: + file.write(f"Needs more {what_it_needs}\n") + + with make_api_client( + config=Config(), ignore_unavailable_components=False + ) as client: + documents = ( + client.post( + "/api/documents", + json=[{"name": document_path.name} for document_path in document_paths], + ) + .raise_for_status() + .json() + ) + + with contextlib.ExitStack() as stack: + files = [ + stack.enter_context(open(document_path, "rb")) + for document_path in document_paths + ] + client.put( + "/api/documents", + files={ + "documents": [ + (document["id"], file) + for document, file in zip(documents, files) + ] + }, + ) + + response = client.get("/api/documents") + response.raise_for_status() + + # Sort the items in case they are retrieved in different orders + def _sorting_key(d): + return d["id"] + + assert sorted(documents, key=_sorting_key) == sorted( + response.json(), key=_sorting_key + ) From 773a0308eb34a61e08c3abe1e316506d6615f7bc Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Thu, 30 Jan 2025 20:44:10 -0800 Subject: [PATCH 25/49] Add test for `GET document` --- tests/deploy/api/test_endpoints.py | 33 ++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index b036a47c..cb3d7996 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -56,3 +56,36 @@ def _sorting_key(d): assert sorted(documents, key=_sorting_key) == sorted( response.json(), key=_sorting_key ) + + +def test_get_document(tmp_local_root): + config = Config(local_root=tmp_local_root) + + document_root = config.local_root / "documents" + document_root.mkdir(exist_ok=True) + document_path = document_root / "test_get_document.txt" + with open(document_path, "w") as file: + file.write("Needs more reverb\n") + + with make_api_client( + config=Config(), ignore_unavailable_components=False + ) as client: + document = ( + client.post( + "/api/documents", + json=[{"name": document_path.name}], + ) + .raise_for_status() + .json()[0] + ) + + with open(document_path, "rb") as file: + client.put( + "/api/documents", + files={"documents": [(document["id"], file)]}, + ) + + response = client.get(f"/api/documents/{document['id']}") + response.raise_for_status() + + assert document == response.json() From d0cd74cb243a2a17c5d3492c0251705faf5bb988 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:12:59 -0800 Subject: [PATCH 26/49] Add test for `GET` document content --- tests/deploy/api/test_endpoints.py | 35 ++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index cb3d7996..5c2cbd1b 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -89,3 +89,38 @@ def test_get_document(tmp_local_root): response.raise_for_status() assert document == response.json() + + +def test_get_document_content(tmp_local_root): + config = Config(local_root=tmp_local_root) + + document_root = config.local_root / "documents" + document_root.mkdir(exist_ok=True) + document_path = document_root / "test_get_document_content.txt" + with open(document_path, "w") as file: + file.write("Needs more reverb\n") + + with make_api_client( + config=Config(), ignore_unavailable_components=False + ) as client: + document = ( + client.post( + "/api/documents", + json=[{"name": document_path.name}], + ) + .raise_for_status() + .json()[0] + ) + + with open(document_path, "rb") as file: + client.put( + "/api/documents", + files=[("documents", (document["id"], file))], + ) + + with client.stream( + "GET", f"/api/documents/{document['id']}/content" + ) as response: + received_lines = list(response.iter_lines()) + + assert received_lines == ["Needs more reverb"] From e107d5d8f72033b3a63ff14d7dc4a986a1beffe2 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:25:06 -0800 Subject: [PATCH 27/49] Fix typo in `PUT` methods --- tests/deploy/api/test_endpoints.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 5c2cbd1b..dbd2197c 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -38,12 +38,10 @@ def test_get_documents(tmp_local_root): ] client.put( "/api/documents", - files={ - "documents": [ - (document["id"], file) - for document, file in zip(documents, files) - ] - }, + files=[ + ("documents", (document["id"], file)) + for document, file in zip(documents, files) + ], ) response = client.get("/api/documents") @@ -82,7 +80,7 @@ def test_get_document(tmp_local_root): with open(document_path, "rb") as file: client.put( "/api/documents", - files={"documents": [(document["id"], file)]}, + files=[("documents", (document["id"], file))], ) response = client.get(f"/api/documents/{document['id']}") From 72d007f49d13d66ea8b1f4f2fae853a7120a5453 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 01:26:11 -0800 Subject: [PATCH 28/49] Clean up `raise_for_status()` --- tests/deploy/api/test_endpoints.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index dbd2197c..0f6a47ae 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -44,8 +44,7 @@ def test_get_documents(tmp_local_root): ], ) - response = client.get("/api/documents") - response.raise_for_status() + response = client.get("/api/documents").raise_for_status() # Sort the items in case they are retrieved in different orders def _sorting_key(d): @@ -83,8 +82,7 @@ def test_get_document(tmp_local_root): files=[("documents", (document["id"], file))], ) - response = client.get(f"/api/documents/{document['id']}") - response.raise_for_status() + response = client.get(f"/api/documents/{document['id']}").raise_for_status() assert document == response.json() From f4028fd420dc0db32be9545887a765bec702537a Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:25:29 -0800 Subject: [PATCH 29/49] Use `__getitem__` instead of `next(iter(...))` --- ragna/deploy/_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index 0b97153d..aa81f2d0 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -197,7 +197,7 @@ def get_documents( return self._database.get_documents(session, user=user, ids=ids) def get_document(self, *, user: str, id: uuid.UUID) -> schemas.Document: - return next(iter(self.get_documents(user=user, ids=[id]))) + return self.get_documents(user=user, ids=[id])[0] def create_chat( self, *, user: str, chat_creation: schemas.ChatCreation From 6abfb0090d258dfe0050bef4e9ba293392b3b040 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:36:05 -0800 Subject: [PATCH 30/49] Remove unique names where appropriate --- tests/deploy/api/test_components.py | 4 ++-- tests/deploy/api/test_endpoints.py | 13 ++++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 79e32092..973c2c49 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -51,8 +51,8 @@ def test_unknown_component(tmp_local_root): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" - document_root.mkdir(exist_ok=True) - document_path = document_root / "test_unknown_component.txt" + document_root.mkdir() + document_path = document_root / "test.txt" with open(document_path, "w") as file: file.write("!\n") diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 0f6a47ae..a7022bbc 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -10,10 +10,9 @@ def test_get_documents(tmp_local_root): needs_more_of = ["reverb", "cowbell"] document_root = config.local_root / "documents" - document_root.mkdir(exist_ok=True) + document_root.mkdir() document_paths = [ - document_root / f"test_get_documents_{what_it_needs}.txt" - for what_it_needs in needs_more_of + document_root / f"test{counter}.txt" for counter in range(len(needs_more_of)) ] for what_it_needs, document_path in zip(needs_more_of, document_paths): with open(document_path, "w") as file: @@ -59,8 +58,8 @@ def test_get_document(tmp_local_root): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" - document_root.mkdir(exist_ok=True) - document_path = document_root / "test_get_document.txt" + document_root.mkdir() + document_path = document_root / "test.txt" with open(document_path, "w") as file: file.write("Needs more reverb\n") @@ -91,8 +90,8 @@ def test_get_document_content(tmp_local_root): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" - document_root.mkdir(exist_ok=True) - document_path = document_root / "test_get_document_content.txt" + document_root.mkdir() + document_path = document_root / "test.txt" with open(document_path, "w") as file: file.write("Needs more reverb\n") From 565ccabd28895ce3219014a25725f3693639ca02 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:37:45 -0800 Subject: [PATCH 31/49] Make sorting key not private --- tests/deploy/api/test_endpoints.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index a7022bbc..1fd095a2 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -46,11 +46,11 @@ def test_get_documents(tmp_local_root): response = client.get("/api/documents").raise_for_status() # Sort the items in case they are retrieved in different orders - def _sorting_key(d): + def sorting_key(d): return d["id"] - assert sorted(documents, key=_sorting_key) == sorted( - response.json(), key=_sorting_key + assert sorted(documents, key=sorting_key) == sorted( + response.json(), key=sorting_key ) From 3f0000f73552cf5097ea7098b55c02b7cfb08549 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sat, 1 Feb 2025 03:06:24 -0800 Subject: [PATCH 32/49] Add and use `upload_documents` to minimize repeated code in repeated tasks --- tests/deploy/api/test_endpoints.py | 74 +++++------------------------- tests/deploy/api/utils.py | 27 +++++++++++ 2 files changed, 39 insertions(+), 62 deletions(-) create mode 100644 tests/deploy/api/utils.py diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 1fd095a2..e3bba03b 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -1,12 +1,10 @@ -import contextlib - from ragna.deploy import Config +from tests.deploy.api.utils import upload_documents from tests.deploy.utils import make_api_client def test_get_documents(tmp_local_root): config = Config(local_root=tmp_local_root) - needs_more_of = ["reverb", "cowbell"] document_root = config.local_root / "documents" @@ -21,37 +19,16 @@ def test_get_documents(tmp_local_root): with make_api_client( config=Config(), ignore_unavailable_components=False ) as client: - documents = ( - client.post( - "/api/documents", - json=[{"name": document_path.name} for document_path in document_paths], - ) - .raise_for_status() - .json() - ) - - with contextlib.ExitStack() as stack: - files = [ - stack.enter_context(open(document_path, "rb")) - for document_path in document_paths - ] - client.put( - "/api/documents", - files=[ - ("documents", (document["id"], file)) - for document, file in zip(documents, files) - ], - ) - + documents = upload_documents(client=client, document_paths=document_paths) response = client.get("/api/documents").raise_for_status() - # Sort the items in case they are retrieved in different orders - def sorting_key(d): - return d["id"] + # Sort the items in case they are retrieved in different orders + def sorting_key(d): + return d["id"] - assert sorted(documents, key=sorting_key) == sorted( - response.json(), key=sorting_key - ) + assert sorted(documents, key=sorting_key) == sorted( + response.json(), key=sorting_key + ) def test_get_document(tmp_local_root): @@ -66,24 +43,10 @@ def test_get_document(tmp_local_root): with make_api_client( config=Config(), ignore_unavailable_components=False ) as client: - document = ( - client.post( - "/api/documents", - json=[{"name": document_path.name}], - ) - .raise_for_status() - .json()[0] - ) - - with open(document_path, "rb") as file: - client.put( - "/api/documents", - files=[("documents", (document["id"], file))], - ) - + document = upload_documents(client=client, document_paths=[document_path])[0] response = client.get(f"/api/documents/{document['id']}").raise_for_status() - assert document == response.json() + assert document == response.json() def test_get_document_content(tmp_local_root): @@ -98,24 +61,11 @@ def test_get_document_content(tmp_local_root): with make_api_client( config=Config(), ignore_unavailable_components=False ) as client: - document = ( - client.post( - "/api/documents", - json=[{"name": document_path.name}], - ) - .raise_for_status() - .json()[0] - ) - - with open(document_path, "rb") as file: - client.put( - "/api/documents", - files=[("documents", (document["id"], file))], - ) + document = upload_documents(client=client, document_paths=[document_path])[0] with client.stream( "GET", f"/api/documents/{document['id']}/content" ) as response: received_lines = list(response.iter_lines()) - assert received_lines == ["Needs more reverb"] + assert received_lines == ["Needs more reverb"] diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py new file mode 100644 index 00000000..2a5415a2 --- /dev/null +++ b/tests/deploy/api/utils.py @@ -0,0 +1,27 @@ +import contextlib + + +def upload_documents(*, client, document_paths): + documents = ( + client.post( + "/api/documents", + json=[{"name": document_path.name} for document_path in document_paths], + ) + .raise_for_status() + .json() + ) + + with contextlib.ExitStack() as stack: + files = [ + stack.enter_context(open(document_path, "rb")) + for document_path in document_paths + ] + client.put( + "/api/documents", + files=[ + ("documents", (document["id"], file)) + for document, file in zip(documents, files) + ], + ) + + return documents From 22a31b679ed783830a65228b0462b70b0b2a6d80 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sat, 1 Feb 2025 03:28:10 -0800 Subject: [PATCH 33/49] Store document text content in a variable to be reused --- tests/deploy/api/test_endpoints.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index e3bba03b..79d19eee 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -2,19 +2,22 @@ from tests.deploy.api.utils import upload_documents from tests.deploy.utils import make_api_client +_document_content_text = [ + f"Needs more {needs_more_of}\n" for needs_more_of in ["reverb", "cowbell"] +] + def test_get_documents(tmp_local_root): config = Config(local_root=tmp_local_root) - needs_more_of = ["reverb", "cowbell"] document_root = config.local_root / "documents" document_root.mkdir() document_paths = [ - document_root / f"test{counter}.txt" for counter in range(len(needs_more_of)) + document_root / f"test{idx}.txt" for idx in range(len(_document_content_text)) ] - for what_it_needs, document_path in zip(needs_more_of, document_paths): + for content, document_path in zip(_document_content_text, document_paths): with open(document_path, "w") as file: - file.write(f"Needs more {what_it_needs}\n") + file.write(content) with make_api_client( config=Config(), ignore_unavailable_components=False @@ -38,7 +41,7 @@ def test_get_document(tmp_local_root): document_root.mkdir() document_path = document_root / "test.txt" with open(document_path, "w") as file: - file.write("Needs more reverb\n") + file.write(_document_content_text[0]) with make_api_client( config=Config(), ignore_unavailable_components=False @@ -55,8 +58,9 @@ def test_get_document_content(tmp_local_root): document_root = config.local_root / "documents" document_root.mkdir() document_path = document_root / "test.txt" + document_content = _document_content_text[0] with open(document_path, "w") as file: - file.write("Needs more reverb\n") + file.write(document_content) with make_api_client( config=Config(), ignore_unavailable_components=False @@ -68,4 +72,4 @@ def test_get_document_content(tmp_local_root): ) as response: received_lines = list(response.iter_lines()) - assert received_lines == ["Needs more reverb"] + assert received_lines == [document_content.replace("\n", "")] From d6db90442ae563d074432f7acbf89a58b936b968 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sat, 1 Feb 2025 18:16:35 -0800 Subject: [PATCH 34/49] Use `upload_documents` in `test_components.py` --- tests/deploy/api/test_components.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index 973c2c49..fbbfc226 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -4,6 +4,7 @@ from ragna import assistants from ragna.core import RagnaException from ragna.deploy import Config +from tests.deploy.api.utils import upload_documents from tests.deploy.utils import make_api_app, make_api_client @@ -59,14 +60,7 @@ def test_unknown_component(tmp_local_root): with make_api_client( config=Config(), ignore_unavailable_components=False ) as client: - document = ( - client.post("/api/documents", json=[{"name": document_path.name}]) - .raise_for_status() - .json()[0] - ) - - with open(document_path, "rb") as file: - client.put("/api/documents", files={"documents": (document["id"], file)}) + document = upload_documents(client=client, document_paths=[document_path])[0] response = client.post( "/api/chats", @@ -80,7 +74,7 @@ def test_unknown_component(tmp_local_root): }, ) - assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - error = response.json()["error"] - assert "Unknown component" in error["message"] + error = response.json()["error"] + assert "Unknown component" in error["message"] From 435a268a2ee51e87e736a93d8032e3b01e8c5f4d Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sat, 1 Feb 2025 18:19:57 -0800 Subject: [PATCH 35/49] Fix typo --- tests/deploy/api/test_components.py | 4 +--- tests/deploy/api/test_endpoints.py | 12 +++--------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/deploy/api/test_components.py b/tests/deploy/api/test_components.py index fbbfc226..684a1a22 100644 --- a/tests/deploy/api/test_components.py +++ b/tests/deploy/api/test_components.py @@ -57,9 +57,7 @@ def test_unknown_component(tmp_local_root): with open(document_path, "w") as file: file.write("!\n") - with make_api_client( - config=Config(), ignore_unavailable_components=False - ) as client: + with make_api_client(config=config, ignore_unavailable_components=False) as client: document = upload_documents(client=client, document_paths=[document_path])[0] response = client.post( diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 79d19eee..0939afff 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -19,9 +19,7 @@ def test_get_documents(tmp_local_root): with open(document_path, "w") as file: file.write(content) - with make_api_client( - config=Config(), ignore_unavailable_components=False - ) as client: + with make_api_client(config=config, ignore_unavailable_components=False) as client: documents = upload_documents(client=client, document_paths=document_paths) response = client.get("/api/documents").raise_for_status() @@ -43,9 +41,7 @@ def test_get_document(tmp_local_root): with open(document_path, "w") as file: file.write(_document_content_text[0]) - with make_api_client( - config=Config(), ignore_unavailable_components=False - ) as client: + with make_api_client(config=config, ignore_unavailable_components=False) as client: document = upload_documents(client=client, document_paths=[document_path])[0] response = client.get(f"/api/documents/{document['id']}").raise_for_status() @@ -62,9 +58,7 @@ def test_get_document_content(tmp_local_root): with open(document_path, "w") as file: file.write(document_content) - with make_api_client( - config=Config(), ignore_unavailable_components=False - ) as client: + with make_api_client(config=config, ignore_unavailable_components=False) as client: document = upload_documents(client=client, document_paths=[document_path])[0] with client.stream( From a6acfd6cd6f42217f6bf4359ef4979d0254bdd53 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 01:43:23 -0800 Subject: [PATCH 36/49] Allow for specification of MIME types in `upload_documents` --- tests/deploy/api/utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index 2a5415a2..380573f7 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -1,11 +1,19 @@ import contextlib -def upload_documents(*, client, document_paths): +def upload_documents(*, client, document_paths, mime_types=None): + if mime_types is None: + mime_types = [None for _ in document_paths] documents = ( client.post( "/api/documents", - json=[{"name": document_path.name} for document_path in document_paths], + json=[ + { + "name": document_path.name, + "mime_type": mime_type, + } + for document_path, mime_type in zip(document_paths, mime_types) + ], ) .raise_for_status() .json() From 4cfa644a9aa7ff516f2347b1bd1481b9bd9470b8 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 01:44:08 -0800 Subject: [PATCH 37/49] Test with user-specified MIME types in `test_get_documents` --- tests/deploy/api/test_endpoints.py | 33 ++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 0939afff..dc75efe1 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -1,3 +1,7 @@ +import mimetypes + +import pytest + from ragna.deploy import Config from tests.deploy.api.utils import upload_documents from tests.deploy.utils import make_api_client @@ -7,7 +11,15 @@ ] -def test_get_documents(tmp_local_root): +@pytest.mark.parametrize( + ("mime_type",), + [ + (None,), # Let the mimetypes library decide + ("text/markdown",), + ("application/pdf",), + ], +) +def test_get_documents(tmp_local_root, mime_type): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" @@ -20,7 +32,11 @@ def test_get_documents(tmp_local_root): file.write(content) with make_api_client(config=config, ignore_unavailable_components=False) as client: - documents = upload_documents(client=client, document_paths=document_paths) + documents = upload_documents( + client=client, + document_paths=document_paths, + mime_types=[mime_type for _ in document_paths], + ) response = client.get("/api/documents").raise_for_status() # Sort the items in case they are retrieved in different orders @@ -31,6 +47,19 @@ def sorting_key(d): response.json(), key=sorting_key ) + for document, antwort in zip( + sorted(documents, key=sorting_key), sorted(response.json(), key=sorting_key) + ): + assert ( + document["mime_type"] + == antwort["mime_type"] + == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] + ) + ) + def test_get_document(tmp_local_root): config = Config(local_root=tmp_local_root) From 0959572867ef7b43042fd47725cb24e7c14976d2 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 01:49:59 -0800 Subject: [PATCH 38/49] Make `mime_types` parametrization reusable across multiple tests --- tests/deploy/api/test_endpoints.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index dc75efe1..f7516e9d 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -11,7 +11,7 @@ ] -@pytest.mark.parametrize( +mime_types = pytest.mark.parametrize( ("mime_type",), [ (None,), # Let the mimetypes library decide @@ -19,6 +19,9 @@ ("application/pdf",), ], ) + + +@mime_types def test_get_documents(tmp_local_root, mime_type): config = Config(local_root=tmp_local_root) From 2f9f2b75ff8b70b04364d6c7463fb527306595b5 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 01:55:22 -0800 Subject: [PATCH 39/49] Test with user-specified MIME types in `test_get_document` --- tests/deploy/api/test_endpoints.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index f7516e9d..5a277b1a 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -64,7 +64,8 @@ def sorting_key(d): ) -def test_get_document(tmp_local_root): +@mime_types +def test_get_document(tmp_local_root, mime_type): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" @@ -74,11 +75,25 @@ def test_get_document(tmp_local_root): file.write(_document_content_text[0]) with make_api_client(config=config, ignore_unavailable_components=False) as client: - document = upload_documents(client=client, document_paths=[document_path])[0] + document = upload_documents( + client=client, + document_paths=[document_path], + mime_types=[mime_type], + )[0] response = client.get(f"/api/documents/{document['id']}").raise_for_status() assert document == response.json() + assert ( + document["mime_type"] + == response.json()["mime_type"] + == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] + ) + ) + def test_get_document_content(tmp_local_root): config = Config(local_root=tmp_local_root) From a3456b704f670c74f8272ec622d277fe9c162a5a Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 02:04:31 -0800 Subject: [PATCH 40/49] Test with user-specified MIME types in `test_get_document_content` --- tests/deploy/api/test_endpoints.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 5a277b1a..fff97570 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -95,7 +95,8 @@ def test_get_document(tmp_local_root, mime_type): ) -def test_get_document_content(tmp_local_root): +@mime_types +def test_get_document_content(tmp_local_root, mime_type): config = Config(local_root=tmp_local_root) document_root = config.local_root / "documents" @@ -106,11 +107,26 @@ def test_get_document_content(tmp_local_root): file.write(document_content) with make_api_client(config=config, ignore_unavailable_components=False) as client: - document = upload_documents(client=client, document_paths=[document_path])[0] + document = upload_documents( + client=client, + document_paths=[document_path], + mime_types=[mime_type], + )[0] with client.stream( "GET", f"/api/documents/{document['id']}/content" ) as response: + response_mime_type = response.headers["content-type"].split(";")[0] received_lines = list(response.iter_lines()) assert received_lines == [document_content.replace("\n", "")] + + assert ( + document["mime_type"] + == response_mime_type + == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] + ) + ) From 4d1b1d5d24d63acb09e6a7612a57f4f809039a9c Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 20:38:02 -0800 Subject: [PATCH 41/49] Add `mime_type` to `DocumentRegistration` --- ragna/deploy/_schemas.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ragna/deploy/_schemas.py b/ragna/deploy/_schemas.py index 2c1072fe..2d9c50bd 100644 --- a/ragna/deploy/_schemas.py +++ b/ragna/deploy/_schemas.py @@ -78,6 +78,7 @@ class Components(BaseModel): class DocumentRegistration(BaseModel): name: str metadata: dict[str, Any] = Field(default_factory=dict) + mime_type: str | None = None class Document(BaseModel): From d5c6c16f5133936d867c3ad186087d3ace060f5c Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Mon, 3 Feb 2025 20:38:44 -0800 Subject: [PATCH 42/49] Include `mime_type` when registering documents --- ragna/deploy/_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ragna/deploy/_engine.py b/ragna/deploy/_engine.py index aa81f2d0..bc0b1afe 100644 --- a/ragna/deploy/_engine.py +++ b/ragna/deploy/_engine.py @@ -156,7 +156,9 @@ def register_documents( # We create core.Document's first, because they might update the metadata core_documents = [ self._config.document( - name=registration.name, metadata=registration.metadata + name=registration.name, + metadata=registration.metadata, + mime_type=registration.mime_type, ) for registration in document_registrations ] From 6cfb136f95d70b6d4b57b10b025ca97fdda9f450 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 01:03:55 -0800 Subject: [PATCH 43/49] Use `read` instead of `iter_lines` --- tests/deploy/api/test_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index fff97570..d1d48f2c 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -117,9 +117,9 @@ def test_get_document_content(tmp_local_root, mime_type): "GET", f"/api/documents/{document['id']}/content" ) as response: response_mime_type = response.headers["content-type"].split(";")[0] - received_lines = list(response.iter_lines()) + received_text = response.read().decode("utf-8") - assert received_lines == [document_content.replace("\n", "")] + assert received_text == document_content assert ( document["mime_type"] From 5eafb6d6a5a4ba67f9c5b66f4a07360e295a1f46 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 01:21:27 -0800 Subject: [PATCH 44/49] Remove redundant assertion of equality --- tests/deploy/api/test_endpoints.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index d1d48f2c..05ff330d 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -50,17 +50,12 @@ def sorting_key(d): response.json(), key=sorting_key ) - for document, antwort in zip( - sorted(documents, key=sorting_key), sorted(response.json(), key=sorting_key) - ): - assert ( - document["mime_type"] - == antwort["mime_type"] - == ( - mime_type - if mime_type is not None - else mimetypes.guess_type(document_path.name)[0] - ) + # Assert that the correct MIME types are returned + for antwort in response.json(): + assert antwort["mime_type"] == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] ) @@ -84,14 +79,11 @@ def test_get_document(tmp_local_root, mime_type): assert document == response.json() - assert ( - document["mime_type"] - == response.json()["mime_type"] - == ( - mime_type - if mime_type is not None - else mimetypes.guess_type(document_path.name)[0] - ) + # Assert that the correct MIME type is returned + assert response.json()["mime_type"] == ( + mime_type + if mime_type is not None + else mimetypes.guess_type(document_path.name)[0] ) From 249c4879c2b8ea4aa38015c1a2f99472d9ef6a7e Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 01:41:56 -0800 Subject: [PATCH 45/49] Use `zip(..., strict=True)` to force arguments to be the same length --- tests/deploy/api/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index 380573f7..e34b322d 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -12,7 +12,11 @@ def upload_documents(*, client, document_paths, mime_types=None): "name": document_path.name, "mime_type": mime_type, } - for document_path, mime_type in zip(document_paths, mime_types) + for document_path, mime_type in zip( + document_paths, + mime_types, + strict=True, + ) ], ) .raise_for_status() From 5fd702e4a9a650c2b513490b5e2ba3a8326ab0c8 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 02:06:52 -0800 Subject: [PATCH 46/49] Remove assertion that should be part of other tests --- tests/deploy/api/test_endpoints.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 05ff330d..25f26882 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -50,14 +50,6 @@ def sorting_key(d): response.json(), key=sorting_key ) - # Assert that the correct MIME types are returned - for antwort in response.json(): - assert antwort["mime_type"] == ( - mime_type - if mime_type is not None - else mimetypes.guess_type(document_path.name)[0] - ) - @mime_types def test_get_document(tmp_local_root, mime_type): @@ -79,13 +71,6 @@ def test_get_document(tmp_local_root, mime_type): assert document == response.json() - # Assert that the correct MIME type is returned - assert response.json()["mime_type"] == ( - mime_type - if mime_type is not None - else mimetypes.guess_type(document_path.name)[0] - ) - @mime_types def test_get_document_content(tmp_local_root, mime_type): From 83e76079ffd8c0bb3b3c2e6aba30b7dc71664e13 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 02:33:42 -0800 Subject: [PATCH 47/49] Test equality of bytes, rather than of strings --- tests/deploy/api/test_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index 25f26882..f5103278 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -94,9 +94,9 @@ def test_get_document_content(tmp_local_root, mime_type): "GET", f"/api/documents/{document['id']}/content" ) as response: response_mime_type = response.headers["content-type"].split(";")[0] - received_text = response.read().decode("utf-8") + received_bytes = response.read() - assert received_text == document_content + assert received_bytes == document_content.encode("utf-8") assert ( document["mime_type"] From 3a2d98e0abe3f2ca77bf4f6c328c2fc17e7e8548 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 02:37:19 -0800 Subject: [PATCH 48/49] Assert equal lengths instead of using `zip(..., strict=True)` --- tests/deploy/api/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/deploy/api/utils.py b/tests/deploy/api/utils.py index e34b322d..04352951 100644 --- a/tests/deploy/api/utils.py +++ b/tests/deploy/api/utils.py @@ -4,6 +4,8 @@ def upload_documents(*, client, document_paths, mime_types=None): if mime_types is None: mime_types = [None for _ in document_paths] + else: + assert len(mime_types) == len(document_paths) documents = ( client.post( "/api/documents", @@ -12,11 +14,7 @@ def upload_documents(*, client, document_paths, mime_types=None): "name": document_path.name, "mime_type": mime_type, } - for document_path, mime_type in zip( - document_paths, - mime_types, - strict=True, - ) + for document_path, mime_type in zip(document_paths, mime_types) ], ) .raise_for_status() From 68f302e5c3e1c160112e3a603ed3701c547373d8 Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Tue, 4 Feb 2025 18:34:44 -0800 Subject: [PATCH 49/49] Use `iter_lines` instead of receiving bytes --- tests/deploy/api/test_endpoints.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/deploy/api/test_endpoints.py b/tests/deploy/api/test_endpoints.py index f5103278..5f40e3e8 100644 --- a/tests/deploy/api/test_endpoints.py +++ b/tests/deploy/api/test_endpoints.py @@ -94,9 +94,9 @@ def test_get_document_content(tmp_local_root, mime_type): "GET", f"/api/documents/{document['id']}/content" ) as response: response_mime_type = response.headers["content-type"].split(";")[0] - received_bytes = response.read() + received_lines = list(response.iter_lines()) - assert received_bytes == document_content.encode("utf-8") + assert received_lines == [document_content.replace("\n", "")] assert ( document["mime_type"]