From 93e1dfe1864af618dbe46749debe6bf6269a86a2 Mon Sep 17 00:00:00 2001 From: Pascal MERCIER Date: Wed, 19 Feb 2025 11:17:06 +0100 Subject: [PATCH] fix: AskFileButton can now upload file with proper checking and its own limits --- backend/chainlit/emitter.py | 8 +- backend/chainlit/server.py | 59 ++++---- backend/chainlit/session.py | 3 +- backend/tests/conftest.py | 1 + backend/tests/test_server.py | 128 ++++++++++++++---- cypress/e2e/ask_file/main.py | 9 +- cypress/e2e/ask_multiple_files/main.py | 8 +- .../chat/Messages/Message/AskFileButton.tsx | 28 ++-- .../chat/MessagesContainer/index.tsx | 4 +- libs/react-client/src/api/index.tsx | 8 +- libs/react-client/src/useChatInteract.ts | 4 +- 11 files changed, 188 insertions(+), 72 deletions(-) diff --git a/backend/chainlit/emitter.py b/backend/chainlit/emitter.py index 42b21b58cb..5acb26c0c3 100644 --- a/backend/chainlit/emitter.py +++ b/backend/chainlit/emitter.py @@ -15,6 +15,7 @@ from chainlit.step import StepDict from chainlit.types import ( AskActionResponse, + AskFileSpec, AskSpec, CommandDict, FileDict, @@ -304,8 +305,11 @@ async def send_ask_user( self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False ): """Send a prompt to the UI and wait for a response.""" - + parent_id = str(step_dict["parentId"]) try: + if spec.type == "file": + self.session.files_spec[parent_id] = cast(AskFileSpec, spec) + # Send the prompt to the UI user_res = await self.emit_call( "ask", {"msg": step_dict, "spec": spec.to_dict()}, spec.timeout @@ -366,6 +370,8 @@ async def send_ask_user( if raise_on_timeout: raise e finally: + if parent_id in self.session.files_spec: + del self.session.files_spec[parent_id] await self.task_start() async def send_call_fn( diff --git a/backend/chainlit/server.py b/backend/chainlit/server.py index a49dbbf620..eba9341210 100644 --- a/backend/chainlit/server.py +++ b/backend/chainlit/server.py @@ -58,6 +58,7 @@ from chainlit.oauth_providers import get_oauth_provider from chainlit.secret import 
random_secret from chainlit.types import ( + AskFileSpec, CallActionRequest, DeleteFeedbackRequest, DeleteThreadRequest, @@ -1062,6 +1063,7 @@ async def upload_file( current_user: UserParam, session_id: str, file: UploadFile, + ask_parent_id: Optional[str] = None, ): """Upload a file to the session files directory.""" @@ -1089,8 +1091,15 @@ async def upload_file( assert file.filename, "No filename for uploaded file" assert file.content_type, "No content type for uploaded file" + spec: AskFileSpec = session.files_spec.get(ask_parent_id, None) + if not spec and ask_parent_id: + raise HTTPException( + status_code=404, + detail="Parent message not found", + ) + try: - validate_file_upload(file) + validate_file_upload(file, spec) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -1101,27 +1110,28 @@ async def upload_file( return JSONResponse(content=file_response) -def validate_file_upload(file: UploadFile): - """Validate the file upload as configured in config.features.spontaneous_file_upload. +def validate_file_upload(file: UploadFile, spec: Optional[AskFileSpec] = None): + """Validate the file upload as configured in config.features.spontaneous_file_upload or by AskFileSpec + for a specific message. + Args: file (UploadFile): The file to validate. + spec (AskFileSpec): The file spec to validate against if any. Raises: ValueError: If the file is not allowed. """ - # TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API. 
- # Commenting this check until we find a better solution + if not spec and config.features.spontaneous_file_upload is None: + """Default for a missing config is to allow the fileupload without any restrictions""" + return - # if config.features.spontaneous_file_upload is None: - # """Default for a missing config is to allow the fileupload without any restrictions""" - # return - # if not config.features.spontaneous_file_upload.enabled: - # raise ValueError("File upload is not enabled") + if not spec and not config.features.spontaneous_file_upload.enabled: + raise ValueError("File upload is not enabled") - validate_file_mime_type(file) - validate_file_size(file) + validate_file_mime_type(file, spec) + validate_file_size(file, spec) -def validate_file_mime_type(file: UploadFile): +def validate_file_mime_type(file: UploadFile, spec: Optional[AskFileSpec]): """Validate the file mime type as configured in config.features.spontaneous_file_upload. Args: file (UploadFile): The file to validate. @@ -1129,14 +1139,14 @@ def validate_file_mime_type(file: UploadFile): ValueError: If the file type is not allowed. 
""" - if ( + if not spec and ( config.features.spontaneous_file_upload is None or config.features.spontaneous_file_upload.accept is None ): "Accept is not configured, allowing all file types" return - accept = config.features.spontaneous_file_upload.accept + accept = config.features.spontaneous_file_upload.accept if not spec else spec.accept assert isinstance(accept, List) or isinstance(accept, dict), ( "Invalid configuration for spontaneous_file_upload, accept must be a list or a dict" @@ -1144,11 +1154,11 @@ def validate_file_mime_type(file: UploadFile): if isinstance(accept, List): for pattern in accept: - if fnmatch.fnmatch(file.content_type, pattern): + if fnmatch.fnmatch(str(file.content_type), pattern): return elif isinstance(accept, dict): for pattern, extensions in accept.items(): - if fnmatch.fnmatch(file.content_type, pattern): + if fnmatch.fnmatch(str(file.content_type), pattern): if len(extensions) == 0: return for extension in extensions: @@ -1157,24 +1167,25 @@ def validate_file_mime_type(file: UploadFile): raise ValueError("File type not allowed") -def validate_file_size(file: UploadFile): +def validate_file_size(file: UploadFile, spec: Optional[AskFileSpec]): """Validate the file size as configured in config.features.spontaneous_file_upload. Args: file (UploadFile): The file to validate. Raises: ValueError: If the file size is too large. 
""" - if ( + if not spec and ( config.features.spontaneous_file_upload is None or config.features.spontaneous_file_upload.max_size_mb is None ): return - if ( - file.size is not None - and file.size - > config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024 - ): + max_size_mb = ( + config.features.spontaneous_file_upload.max_size_mb + if not spec + else spec.max_size_mb + ) + if file.size is not None and file.size > max_size_mb * 1024 * 1024: raise ValueError("File size too large") diff --git a/backend/chainlit/session.py b/backend/chainlit/session.py index d7127721d6..8750eebc18 100644 --- a/backend/chainlit/session.py +++ b/backend/chainlit/session.py @@ -8,7 +8,7 @@ import aiofiles from chainlit.logger import logger -from chainlit.types import FileReference +from chainlit.types import AskFileSpec, FileReference if TYPE_CHECKING: from chainlit.types import FileDict @@ -80,6 +80,7 @@ def __init__( self.http_cookie = http_cookie self.files: Dict[str, FileDict] = {} + self.files_spec: Dict[str, AskFileSpec] = {} self.id = id diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 62382958c7..d3f2aac229 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -42,6 +42,7 @@ def create_mock_session(**kwargs) -> Mock: mock.emit = AsyncMock() mock.has_first_interaction = kwargs.get("has_first_interaction", True) mock.files = kwargs.get("files", {}) + mock.files_spec = kwargs.get("files_spec", {}) return mock diff --git a/backend/tests/test_server.py b/backend/tests/test_server.py index f9f0422651..327dc478f2 100644 --- a/backend/tests/test_server.py +++ b/backend/tests/test_server.py @@ -15,6 +15,7 @@ SpontaneousFileUploadFeature, ) from chainlit.server import app +from chainlit.types import AskFileSpec from chainlit.user import PersistedUser @@ -500,36 +501,36 @@ def test_upload_file_unauthorized( assert response.status_code == 422 -# def test_upload_file_disabled( -# test_client: TestClient, -# test_config: ChainlitConfig, 
-# mock_session_get_by_id_patched: Mock, -# monkeypatch: pytest.MonkeyPatch, -# ): -# """Test file upload being disabled by config.""" +def test_upload_file_disabled( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, + monkeypatch: pytest.MonkeyPatch, +): + """Test file upload being disabled by config.""" -# # Set accept in config -# monkeypatch.setattr( -# test_config.features, -# "spontaneous_file_upload", -# SpontaneousFileUploadFeature(enabled=False), -# ) + # Set accept in config + monkeypatch.setattr( + test_config.features, + "spontaneous_file_upload", + SpontaneousFileUploadFeature(enabled=False), + ) -# # Prepare the files to upload -# file_content = b"Sample file content" -# files = { -# "file": ("test_upload.txt", file_content, "text/plain"), -# } + # Prepare the files to upload + file_content = b"Sample file content" + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } -# # Make the POST request to upload the file -# response = test_client.post( -# "/project/file", -# files=files, -# params={"session_id": mock_session_get_by_id_patched.id}, -# ) + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={"session_id": mock_session_get_by_id_patched.id}, + ) -# # Verify the response -# assert response.status_code == 400 + # Verify the response + assert response.status_code == 400 @pytest.mark.parametrize( @@ -639,7 +640,7 @@ def test_upload_file_size_check( monkeypatch.setattr( test_config.features, "spontaneous_file_upload", - SpontaneousFileUploadFeature(max_size_mb=max_size_mb), + SpontaneousFileUploadFeature(max_size_mb=max_size_mb, enabled=True), ) # Prepare the files to upload @@ -669,6 +670,79 @@ def test_upload_file_size_check( assert response.status_code == expected_status +@pytest.mark.parametrize( + ( + "file_content", + "content_multiplier", + "max_size_mb", + "parent_id", + "expected_status", + "accept", + ), + [ 
+ (b"1", 1, 1, "mocked_parent_id", 200, ["text/plain"]), + (b"11", 1024 * 1024, 1, "mocked_parent_id", 400, ["text/plain"]), + (b"11", 1, 1, "invalid_parent_id", 404, ["text/plain"]), + (b"11", 1, 1, "mocked_parent_id", 400, ["image/gif"]), + ], +) +def test_ask_file_with_spontaneous_upload_disabled( + test_client: TestClient, + test_config: ChainlitConfig, + mock_session_get_by_id_patched: Mock, + monkeypatch: pytest.MonkeyPatch, + file_content: bytes, + content_multiplier: int, + max_size_mb: int, + parent_id: str, + expected_status: int, + accept: list[str], +): + """Test file upload being disabled by config.""" + + # Set accept in config + monkeypatch.setattr( + test_config.features, + "spontaneous_file_upload", + SpontaneousFileUploadFeature(enabled=False), + ) + + # Prepare the files to upload + file_content = file_content * content_multiplier + files = { + "file": ("test_upload.txt", file_content, "text/plain"), + } + + expected_file_id = "mocked_file_id" + mock_session_get_by_id_patched.persist_file = AsyncMock( + return_value={ + "id": expected_file_id, + "name": "test_upload.txt", + "type": "text/plain", + "size": len(file_content), + } + ) + + mock_session_get_by_id_patched.files_spec = { + "mocked_parent_id": AskFileSpec( + timeout=1, type="file", accept=accept, max_files=1, max_size_mb=max_size_mb + ) + } + + # Make the POST request to upload the file + response = test_client.post( + "/project/file", + files=files, + params={ + "session_id": mock_session_get_by_id_patched.id, + "ask_parent_id": parent_id, + }, + ) + + # Verify the response + assert response.status_code == expected_status + + def test_project_translations_file_path_traversal( test_client: TestClient, monkeypatch: pytest.MonkeyPatch ): diff --git a/cypress/e2e/ask_file/main.py b/cypress/e2e/ask_file/main.py index 427b61561f..877c5216e6 100644 --- a/cypress/e2e/ask_file/main.py +++ b/cypress/e2e/ask_file/main.py @@ -17,7 +17,14 @@ async def start(): ).send() files = await 
cl.AskFileMessage( - content="Please upload a python file to begin!", accept={"text/plain": [".py"]} + content="Please upload a python file to begin!", + accept={ + "text/plain": [".py", ".txt"], + # Some browser / os report it as text/plain but some as text/x-python when doing drag&drop + "text/x-python": [".py"], + # Or even as application/octet-stream when using the select file dialog + "application/octet-stream": [".py"], + }, ).send() py_file = files[0] diff --git a/cypress/e2e/ask_multiple_files/main.py b/cypress/e2e/ask_multiple_files/main.py index ddb7fdcbdc..bc76812a95 100644 --- a/cypress/e2e/ask_multiple_files/main.py +++ b/cypress/e2e/ask_multiple_files/main.py @@ -6,7 +6,13 @@ async def start(): files = await cl.AskFileMessage( content="Please upload from one to two python files to begin!", max_files=2, - accept={"text/plain": [".py"]}, + accept={ + "text/plain": [".py", ".txt"], + # Some browser / os report it as text/plain but some as text/x-python when doing drag&drop + "text/x-python": [".py"], + # Or even as application/octet-stream when using the select file dialog + "application/octet-stream": [".py"], + }, ).send() file_names = [file.name for file in files] diff --git a/frontend/src/components/chat/Messages/Message/AskFileButton.tsx b/frontend/src/components/chat/Messages/Message/AskFileButton.tsx index 6dcdc2dfa3..c9eb6a5255 100644 --- a/frontend/src/components/chat/Messages/Message/AskFileButton.tsx +++ b/frontend/src/components/chat/Messages/Message/AskFileButton.tsx @@ -19,9 +19,11 @@ interface UploadState { interface _AskFileButtonProps { askUser: IAsk; + parentId?: string; uploadFile: ( file: File, - onProgress: (progress: number) => void + onProgress: (progress: number) => void, + parentId?: string ) => { xhr: XMLHttpRequest; promise: Promise; @@ -90,16 +92,20 @@ const _AskFileButton = ({ const promises: Promise[] = []; const newUploads = files.map((file, index) => { - const { xhr, promise } = uploadFile(file, (progress) => { - 
setUploads((prev) => - prev.map((upload, i) => { - if (i === index) { - return { ...upload, progress }; - } - return upload; - }) - ); - }); + const { xhr, promise } = uploadFile( + file, + (progress) => { + setUploads((prev) => + prev.map((upload, i) => { + if (i === index) { + return { ...upload, progress }; + } + return upload; + }) + ); + }, + askUser?.parentId + ); promises.push(promise); return { progress: 0, uploaded: false, cancel: () => xhr.abort() }; }); diff --git a/frontend/src/components/chat/MessagesContainer/index.tsx b/frontend/src/components/chat/MessagesContainer/index.tsx index 7c4df3a521..5903c0aed5 100644 --- a/frontend/src/components/chat/MessagesContainer/index.tsx +++ b/frontend/src/components/chat/MessagesContainer/index.tsx @@ -36,8 +36,8 @@ const MessagesContainer = ({ navigate }: Props) => { const { t } = useTranslation(); const uploadFile = useCallback( - (file: File, onProgress: (progress: number) => void) => { - return _uploadFile(file, onProgress); + (file: File, onProgress: (progress: number) => void, parentId?: string) => { + return _uploadFile(file, onProgress, parentId); }, [_uploadFile] ); diff --git a/libs/react-client/src/api/index.tsx b/libs/react-client/src/api/index.tsx index 7df632f038..937b26478b 100644 --- a/libs/react-client/src/api/index.tsx +++ b/libs/react-client/src/api/index.tsx @@ -227,7 +227,8 @@ export class ChainlitAPI extends APIBase { uploadFile( file: File, onProgress: (progress: number) => void, - sessionId: string + sessionId: string, + parentId?: string ) { const xhr = new XMLHttpRequest(); xhr.withCredentials = true; @@ -236,9 +237,12 @@ export class ChainlitAPI extends APIBase { const formData = new FormData(); formData.append('file', file); + const ask_parent_id = parentId ? 
`&ask_parent_id=${parentId}` : ''; xhr.open( 'POST', - this.buildEndpoint(`/project/file?session_id=${sessionId}`), + this.buildEndpoint( + `/project/file?session_id=${sessionId}${ask_parent_id}` + ), true ); diff --git a/libs/react-client/src/useChatInteract.ts b/libs/react-client/src/useChatInteract.ts index bd5499ef9c..e18639da0e 100644 --- a/libs/react-client/src/useChatInteract.ts +++ b/libs/react-client/src/useChatInteract.ts @@ -152,8 +152,8 @@ const useChatInteract = () => { }, [session?.socket]); const uploadFile = useCallback( - (file: File, onProgress: (progress: number) => void) => { - return client.uploadFile(file, onProgress, sessionId); + (file: File, onProgress: (progress: number) => void, parentId?: string) => { + return client.uploadFile(file, onProgress, sessionId, parentId); }, [sessionId] );