Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: AskFileButton can now upload files with proper checking and its own limits #1911

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backend/chainlit/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from chainlit.step import StepDict
from chainlit.types import (
AskActionResponse,
AskFileSpec,
AskSpec,
CommandDict,
FileDict,
Expand Down Expand Up @@ -304,8 +305,11 @@ async def send_ask_user(
self, step_dict: StepDict, spec: AskSpec, raise_on_timeout=False
):
"""Send a prompt to the UI and wait for a response."""

parent_id = str(step_dict["parentId"])
try:
if spec.type == "file":
self.session.files_spec[parent_id] = cast(AskFileSpec, spec)

# Send the prompt to the UI
user_res = await self.emit_call(
"ask", {"msg": step_dict, "spec": spec.to_dict()}, spec.timeout
Expand Down Expand Up @@ -366,6 +370,8 @@ async def send_ask_user(
if raise_on_timeout:
raise e
finally:
if parent_id in self.session.files_spec:
del self.session.files_spec[parent_id]
await self.task_start()

async def send_call_fn(
Expand Down
59 changes: 35 additions & 24 deletions backend/chainlit/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from chainlit.oauth_providers import get_oauth_provider
from chainlit.secret import random_secret
from chainlit.types import (
AskFileSpec,
CallActionRequest,
DeleteFeedbackRequest,
DeleteThreadRequest,
Expand Down Expand Up @@ -1062,6 +1063,7 @@ async def upload_file(
current_user: UserParam,
session_id: str,
file: UploadFile,
ask_parent_id: Optional[str] = None,
):
"""Upload a file to the session files directory."""

Expand Down Expand Up @@ -1089,8 +1091,15 @@ async def upload_file(
assert file.filename, "No filename for uploaded file"
assert file.content_type, "No content type for uploaded file"

spec: AskFileSpec = session.files_spec.get(ask_parent_id, None)
if not spec and ask_parent_id:
raise HTTPException(
status_code=404,
detail="Parent message not found",
)

try:
validate_file_upload(file)
validate_file_upload(file, spec)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

Expand All @@ -1101,54 +1110,55 @@ async def upload_file(
return JSONResponse(content=file_response)


def validate_file_upload(file: UploadFile):
"""Validate the file upload as configured in config.features.spontaneous_file_upload.
def validate_file_upload(file: UploadFile, spec: Optional[AskFileSpec] = None):
"""Validate the file upload as configured in config.features.spontaneous_file_upload or by AskFileSpec
for a specific message.

Args:
file (UploadFile): The file to validate.
spec (AskFileSpec): The file spec to validate against if any.
Raises:
ValueError: If the file is not allowed.
"""
# TODO: This logic/endpoint is shared across spontaneous uploads and the AskFileMessage API.
# Commenting this check until we find a better solution
if not spec and config.features.spontaneous_file_upload is None:
"""Default for a missing config is to allow the fileupload without any restrictions"""
return

# if config.features.spontaneous_file_upload is None:
# """Default for a missing config is to allow the fileupload without any restrictions"""
# return
# if not config.features.spontaneous_file_upload.enabled:
# raise ValueError("File upload is not enabled")
if not spec and not config.features.spontaneous_file_upload.enabled:
raise ValueError("File upload is not enabled")

validate_file_mime_type(file)
validate_file_size(file)
validate_file_mime_type(file, spec)
validate_file_size(file, spec)


def validate_file_mime_type(file: UploadFile):
def validate_file_mime_type(file: UploadFile, spec: Optional[AskFileSpec]):
"""Validate the file mime type as configured in config.features.spontaneous_file_upload.
Args:
file (UploadFile): The file to validate.
Raises:
ValueError: If the file type is not allowed.
"""

if (
if not spec and (
config.features.spontaneous_file_upload is None
or config.features.spontaneous_file_upload.accept is None
):
"Accept is not configured, allowing all file types"
return

accept = config.features.spontaneous_file_upload.accept
accept = config.features.spontaneous_file_upload.accept if not spec else spec.accept

assert isinstance(accept, List) or isinstance(accept, dict), (
"Invalid configuration for spontaneous_file_upload, accept must be a list or a dict"
)

if isinstance(accept, List):
for pattern in accept:
if fnmatch.fnmatch(file.content_type, pattern):
if fnmatch.fnmatch(str(file.content_type), pattern):
return
elif isinstance(accept, dict):
for pattern, extensions in accept.items():
if fnmatch.fnmatch(file.content_type, pattern):
if fnmatch.fnmatch(str(file.content_type), pattern):
if len(extensions) == 0:
return
for extension in extensions:
Expand All @@ -1157,24 +1167,25 @@ def validate_file_mime_type(file: UploadFile):
raise ValueError("File type not allowed")


def validate_file_size(file: UploadFile):
def validate_file_size(file: UploadFile, spec: Optional[AskFileSpec]):
"""Validate the file size as configured in config.features.spontaneous_file_upload.
Args:
file (UploadFile): The file to validate.
Raises:
ValueError: If the file size is too large.
"""
if (
if not spec and (
config.features.spontaneous_file_upload is None
or config.features.spontaneous_file_upload.max_size_mb is None
):
return

if (
file.size is not None
and file.size
> config.features.spontaneous_file_upload.max_size_mb * 1024 * 1024
):
max_size_mb = (
config.features.spontaneous_file_upload.max_size_mb
if not spec
else spec.max_size_mb
)
if file.size is not None and file.size > max_size_mb * 1024 * 1024:
raise ValueError("File size too large")


Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import aiofiles

from chainlit.logger import logger
from chainlit.types import FileReference
from chainlit.types import AskFileSpec, FileReference

if TYPE_CHECKING:
from chainlit.types import FileDict
Expand Down Expand Up @@ -80,6 +80,7 @@ def __init__(
self.http_cookie = http_cookie

self.files: Dict[str, FileDict] = {}
self.files_spec: Dict[str, AskFileSpec] = {}

self.id = id

Expand Down
1 change: 1 addition & 0 deletions backend/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def create_mock_session(**kwargs) -> Mock:
mock.emit = AsyncMock()
mock.has_first_interaction = kwargs.get("has_first_interaction", True)
mock.files = kwargs.get("files", {})
mock.files_spec = kwargs.get("files_spec", {})

return mock

Expand Down
128 changes: 101 additions & 27 deletions backend/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SpontaneousFileUploadFeature,
)
from chainlit.server import app
from chainlit.types import AskFileSpec
from chainlit.user import PersistedUser


Expand Down Expand Up @@ -500,36 +501,36 @@ def test_upload_file_unauthorized(
assert response.status_code == 422


# def test_upload_file_disabled(
# test_client: TestClient,
# test_config: ChainlitConfig,
# mock_session_get_by_id_patched: Mock,
# monkeypatch: pytest.MonkeyPatch,
# ):
# """Test file upload being disabled by config."""
def test_upload_file_disabled(
test_client: TestClient,
test_config: ChainlitConfig,
mock_session_get_by_id_patched: Mock,
monkeypatch: pytest.MonkeyPatch,
):
"""Test file upload being disabled by config."""

# # Set accept in config
# monkeypatch.setattr(
# test_config.features,
# "spontaneous_file_upload",
# SpontaneousFileUploadFeature(enabled=False),
# )
# Disable spontaneous file upload in config
monkeypatch.setattr(
test_config.features,
"spontaneous_file_upload",
SpontaneousFileUploadFeature(enabled=False),
)

# # Prepare the files to upload
# file_content = b"Sample file content"
# files = {
# "file": ("test_upload.txt", file_content, "text/plain"),
# }
# Prepare the files to upload
file_content = b"Sample file content"
files = {
"file": ("test_upload.txt", file_content, "text/plain"),
}

# # Make the POST request to upload the file
# response = test_client.post(
# "/project/file",
# files=files,
# params={"session_id": mock_session_get_by_id_patched.id},
# )
# Make the POST request to upload the file
response = test_client.post(
"/project/file",
files=files,
params={"session_id": mock_session_get_by_id_patched.id},
)

# # Verify the response
# assert response.status_code == 400
# Verify the response
assert response.status_code == 400


@pytest.mark.parametrize(
Expand Down Expand Up @@ -639,7 +640,7 @@ def test_upload_file_size_check(
monkeypatch.setattr(
test_config.features,
"spontaneous_file_upload",
SpontaneousFileUploadFeature(max_size_mb=max_size_mb),
SpontaneousFileUploadFeature(max_size_mb=max_size_mb, enabled=True),
)

# Prepare the files to upload
Expand Down Expand Up @@ -669,6 +670,79 @@ def test_upload_file_size_check(
assert response.status_code == expected_status


@pytest.mark.parametrize(
(
"file_content",
"content_multiplier",
"max_size_mb",
"parent_id",
"expected_status",
"accept",
),
[
(b"1", 1, 1, "mocked_parent_id", 200, ["text/plain"]),
(b"11", 1024 * 1024, 1, "mocked_parent_id", 400, ["text/plain"]),
(b"11", 1, 1, "invalid_parent_id", 404, ["text/plain"]),
(b"11", 1, 1, "mocked_parent_id", 400, ["image/gif"]),
],
)
def test_ask_file_with_spontaneous_upload_disabled(
test_client: TestClient,
test_config: ChainlitConfig,
mock_session_get_by_id_patched: Mock,
monkeypatch: pytest.MonkeyPatch,
file_content: bytes,
content_multiplier: int,
max_size_mb: int,
parent_id: str,
expected_status: int,
accept: list[str],
):
"""Test that uploads tied to an AskFileSpec are validated against that spec when spontaneous file upload is disabled."""

# Disable spontaneous file upload in config
monkeypatch.setattr(
test_config.features,
"spontaneous_file_upload",
SpontaneousFileUploadFeature(enabled=False),
)

# Prepare the files to upload
file_content = file_content * content_multiplier
files = {
"file": ("test_upload.txt", file_content, "text/plain"),
}

expected_file_id = "mocked_file_id"
mock_session_get_by_id_patched.persist_file = AsyncMock(
return_value={
"id": expected_file_id,
"name": "test_upload.txt",
"type": "text/plain",
"size": len(file_content),
}
)

mock_session_get_by_id_patched.files_spec = {
"mocked_parent_id": AskFileSpec(
timeout=1, type="file", accept=accept, max_files=1, max_size_mb=max_size_mb
)
}

# Make the POST request to upload the file
response = test_client.post(
"/project/file",
files=files,
params={
"session_id": mock_session_get_by_id_patched.id,
"ask_parent_id": parent_id,
},
)

# Verify the response
assert response.status_code == expected_status


def test_project_translations_file_path_traversal(
test_client: TestClient, monkeypatch: pytest.MonkeyPatch
):
Expand Down
9 changes: 8 additions & 1 deletion cypress/e2e/ask_file/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ async def start():
).send()

files = await cl.AskFileMessage(
content="Please upload a python file to begin!", accept={"text/plain": [".py"]}
content="Please upload a python file to begin!",
accept={
"text/plain": [".py", ".txt"],
# Some browser / os report it as text/plain but some as text/x-python when doing drag&drop
"text/x-python": [".py"],
# Or even as application/octet-stream when using the select file dialog
"application/octet-stream": [".py"],
},
).send()
py_file = files[0]

Expand Down
8 changes: 7 additions & 1 deletion cypress/e2e/ask_multiple_files/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ async def start():
files = await cl.AskFileMessage(
content="Please upload from one to two python files to begin!",
max_files=2,
accept={"text/plain": [".py"]},
accept={
"text/plain": [".py", ".txt"],
# Some browser / os report it as text/plain but some as text/x-python when doing drag&drop
"text/x-python": [".py"],
# Or even as application/octet-stream when using the select file dialog
"application/octet-stream": [".py"],
},
).send()

file_names = [file.name for file in files]
Expand Down
Loading
Loading