diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..a1154a08a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +**/__pycache__/ +.pytest_cache/ +.git +.github +.vscode +.DS_Store +.env +docs +examples +tests +x diff --git a/docker/Dockerfile.api b/docker/Dockerfile.api new file mode 100644 index 000000000..0a9d118e7 --- /dev/null +++ b/docker/Dockerfile.api @@ -0,0 +1,25 @@ +FROM python:3.12 + +WORKDIR /app + +ARG ELL_EXTRAS="api-server postgres mqtt" + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Poetry +RUN pip install --no-cache-dir poetry + +# Copy only requirements to cache them in docker layer +COPY pyproject.toml poetry.lock* ./ + +# Project initialization: +RUN poetry config virtualenvs.create false \ + && poetry install --extras="api-server ${ELL_EXTRAS}" --no-interaction --no-ansi + +# Copy project +COPY src . + +CMD ["python", "-m", "ell.api"] diff --git a/docker/Dockerfile.studio b/docker/Dockerfile.studio new file mode 100644 index 000000000..76217e336 --- /dev/null +++ b/docker/Dockerfile.studio @@ -0,0 +1,49 @@ +# Start with a Node.js base image for building the React app +FROM node:20 AS client-builder + +WORKDIR /app/ell-studio + +# Copy package.json and package-lock.json (if available) +COPY ell-studio/package.json ell-studio/package-lock.json* ./ + +# Install dependencies +RUN npm ci + +# Copy the rest of the client code +COPY ell-studio . + +# Build the React app +RUN npm run build + +# Now, start with the Python base image +FROM python:3.12 + +ARG ELL_EXTRAS="" + +RUN echo "ELL_EXTRAS=${ELL_EXTRAS}" + +WORKDIR /app + + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + gcc \ + && rm -rf /var/lib/apt/lists/* + +# Install Poetry +RUN pip install --no-cache-dir poetry + +# Copy only requirements to cache them in docker layer +COPY pyproject.toml poetry.lock* ./ + +# Project initialization: +RUN poetry config virtualenvs.create false \ + && poetry install --extras="studio ${ELL_EXTRAS}" --no-interaction --no-ansi + +# Copy the Python project +COPY src . + +# Copy the built React app from the client-builder stage +COPY --from=client-builder /app/ell-studio/build /app/ell/studio/static + +CMD ["python", "-m", "ell.studio"] \ No newline at end of file diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml new file mode 100644 index 000000000..def2b9d77 --- /dev/null +++ b/docker/docker-compose.yml @@ -0,0 +1,115 @@ +name: ell +version: "3.9" +services: + api: + build: + context: .. + dockerfile: docker/Dockerfile.api + args: + ELL_EXTRAS: postgres mqtt minio + tags: + - ell-api + + ports: + - "8081:8081" + environment: + - ELL_API_HOST=0.0.0.0 + - ELL_API_PORT=8081 + - ELL_PG_CONNECTION_STRING=postgresql://ell_user:ell_password@postgres:5432/ell_db + - ELL_MQTT_CONNECTION_STRING=mqtt://mqtt:1883 + - LOG_LEVEL=10 # debug + - ELL_MINIO_ENDPOINT=minio:9000 + - ELL_MINIO_ACCESS_KEY=minio_user + - ELL_MINIO_SECRET_KEY=minio_password + - ELL_MINIO_BUCKET=ell-bucket + depends_on: + - postgres + - mqtt + - minio + + studio: + build: + context: .. 
+ dockerfile: docker/Dockerfile.studio + args: + ELL_EXTRAS: postgres mqtt minio + tags: + - ell-studio + ports: + - "8080:8080" + environment: + - ELL_STUDIO_HOST=0.0.0.0 + - ELL_STUDIO_PORT=8080 + - ELL_PG_CONNECTION_STRING=postgresql://ell_user:ell_password@postgres:5432/ell_db + - ELL_MQTT_CONNECTION_STRING=mqtt://mqtt:1883 + - ELL_MINIO_ENDPOINT=minio:9000 + - ELL_MINIO_ACCESS_KEY=minio_user + - ELL_MINIO_SECRET_KEY=minio_password + - ELL_MINIO_BUCKET=ell-bucket + depends_on: + - postgres + - mqtt + - minio + develop: + watch: + - action: sync+restart + path: ./src/ell/studio + target: /app/ell/studio + + mqtt: + image: eclipse-mosquitto:latest + ports: + - "1883:1883" + command: mosquitto -c /mosquitto/config/mosquitto.conf + volumes: + - mosquitto_config:/mosquitto/config + depends_on: + - mqtt-config + + mqtt-config: + image: busybox + volumes: + - mosquitto_config:/mosquitto/config + command: > + sh -c "echo 'listener 1883' > /mosquitto/config/mosquitto.conf && + echo 'allow_anonymous true' >> /mosquitto/config/mosquitto.conf" + + postgres: + image: postgres:16 + environment: + - POSTGRES_USER=ell_user + - POSTGRES_PASSWORD=ell_password + - POSTGRES_DB=ell_db + volumes: + - postgres_data:/var/lib/postgresql/data + ports: + - "5432:5432" + + minio: + image: minio/minio:latest + ports: + - "9000:9000" # API port + - "9001:9001" # Console port + environment: + - MINIO_ROOT_USER=minio_user + - MINIO_ROOT_PASSWORD=minio_password + volumes: + - minio_data:/data + command: server --console-address ":9001" --address ":9000" /data + + minio-init: + image: minio/mc + depends_on: + - minio + entrypoint: > + /bin/sh -c " + sleep 5; + /usr/bin/mc alias set myminio http://minio:9000 minio_user minio_password --api S3v4; + /usr/bin/mc mb myminio/ell-bucket; + exit 0; + " + +volumes: + postgres_data: + mosquitto_config: + minio_data: \ No newline at end of file diff --git a/examples/future/http_serializer.py b/examples/future/http_serializer.py new file mode 100644 index 000000000..20844f3cf --- /dev/null +++ b/examples/future/http_serializer.py @@ -0,0 +1,25 @@ +from pydantic import Field +import ell + +ell.init(api_url='http://localhost:8081') + +@ell.tool() +def get_weather(location: str = Field(description="The full name of a city and country, e.g. San Francisco, CA, USA")): + """Get the current weather for a given location.""" + # Simulated weather API call + return f"The weather in {location} is sunny." + +@ell.complex(model="gpt-4o", tools=[get_weather]) +def travel_planner(destination: str): + """Plan a trip based on the destination and current weather.""" + return [ + ell.system("You are a travel planner. 
Use the weather tool to provide relevant advice."), + ell.user(f"Plan a trip to {destination}") + ] + +result = travel_planner("Paris") +print(result.text) # Prints travel advice +if result.tool_calls: + # This is done so that we can pass the tool calls to the language model + tool_results = result.call_tools_and_collect_as_message() + print("Weather info:", (tool_results.text)) diff --git a/examples/future/images_minio.py b/examples/future/images_minio.py new file mode 100644 index 000000000..cc0040d60 --- /dev/null +++ b/examples/future/images_minio.py @@ -0,0 +1,38 @@ +from PIL import Image +import os + +import ell +from ell.stores.minio import MinioBlobStore, MinioConfig +from ell.stores.sql import PostgresStore + + +# Load the image using PIL +big_picture = Image.open(os.path.join(os.path.dirname(__file__), "bigpicture.jpg")) + +@ell.simple(model="gpt-4o", temperature=0.5) +def make_a_joke_about_the_image(image: Image.Image): + return [ + ell.system("You are a meme maker. You are given an image and you must make a joke about it."), + ell.user(image) + ] + + + +if __name__ == "__main__": + # Run "docker compose up" inside the `docker` folder to run + # ell studio with minio for blob storage with postgres + blob_store = MinioBlobStore( + config=MinioConfig( + endpoint="localhost:9000", + access_key="minio_user", + secret_key="minio_password", + bucket="ell-bucket", + ) + ) + store = PostgresStore( + db_uri="postgresql://ell_user:ell_password@localhost:5432/ell_db", + blob_store=blob_store, + ) + ell.init(store=store, autocommit=True, verbose=True) + joke = make_a_joke_about_the_image(big_picture) + print(joke) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 86afded79..6a759e5b7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,19 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. 
+ +[[package]] +name = "aiomqtt" +version = "2.3.0" +description = "The idiomatic asyncio MQTT client, wrapped around paho-mqtt" +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "aiomqtt-2.3.0-py3-none-any.whl", hash = "sha256:127926717bd6b012d1630f9087f24552eb9c4af58205bc2964f09d6e304f7e63"}, + {file = "aiomqtt-2.3.0.tar.gz", hash = "sha256:312feebe20bc76dc7c20916663011f3bd37aa6f42f9f687a19a1c58308d80d47"}, +] + +[package.dependencies] +paho-mqtt = ">=2.1.0,<3.0.0" +typing-extensions = {version = ">=4.4.0,<5.0.0", markers = "python_version < \"3.10\""} [[package]] name = "alabaster" @@ -88,6 +103,63 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "argon2-cffi" +version = "23.1.0" +description = "Argon2 for Python" +optional = true +python-versions = ">=3.7" +files = [ + {file = "argon2_cffi-23.1.0-py3-none-any.whl", hash = "sha256:c670642b78ba29641818ab2e68bd4e6a78ba53b7eff7b4c3815ae16abf91c7ea"}, + {file = "argon2_cffi-23.1.0.tar.gz", hash = "sha256:879c3e79a2729ce768ebb7d36d4609e3a78a4ca2ec3a9f12286ca057e3d0db08"}, +] + +[package.dependencies] +argon2-cffi-bindings = "*" + +[package.extras] +dev = ["argon2-cffi[tests,typing]", "tox (>4)"] +docs = ["furo", "myst-parser", "sphinx", "sphinx-copybutton", "sphinx-notfound-page"] +tests = ["hypothesis", "pytest"] +typing = ["mypy"] + +[[package]] +name = "argon2-cffi-bindings" +version = "21.2.0" +description = "Low-level CFFI bindings for Argon2" +optional = true +python-versions = ">=3.6" +files = [ + {file = "argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082"}, + {file = "argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f"}, + {file = "argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = 
"sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3e385d1c39c520c08b53d63300c3ecc28622f076f4c2b0e6d7e796e9f6502194"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2c3e3cc67fdb7d82c4718f19b4e7a87123caf8a93fde7e23cf66ac0337d3cb3f"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6a22ad9800121b71099d0fb0a65323810a15f2e292f2ba450810a7316e128ee5"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f9f8b450ed0547e3d473fdc8612083fd08dd2120d6ac8f73828df9b7d45bb351"}, + {file = "argon2_cffi_bindings-21.2.0-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:93f9bf70084f97245ba10ee36575f0c3f1e7d7724d67d8e5b08e61787c320ed7"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3b9ef65804859d335dc6b31582cad2c5166f0c3e7975f324d9ffaa34ee7e6583"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d4966ef5848d820776f5f562a7d45fdd70c2f330c961d0d745b784034bd9f48d"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20ef543a89dee4db46a1a6e206cd015360e5a75822f76df533845c3cbaf72670"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed2937d286e2ad0cc79a7087d3c272832865f779430e0cc2b4f3718d3159b0cb"}, + {file = "argon2_cffi_bindings-21.2.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:5e00316dabdaea0b2dd82d141cc66889ced0cdcbfa599e8b471cf22c620c329a"}, +] + +[package.dependencies] +cffi = ">=1.0.1" + +[package.extras] +dev = ["cogapp", "pre-commit", "pytest", "wheel"] +tests = ["pytest"] + [[package]] name = "attrs" version = "24.2.0" @@ -203,6 +275,85 @@ files = [ {file = "certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9"}, ] +[[package]] +name = "cffi" +version = "1.17.1" +description = "Foreign Function Interface for Python calling C code." 
+optional = true +python-versions = ">=3.8" +files = [ + {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, + {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5da5719280082ac6bd9aa7becb3938dc9f9cbd57fac7d2871717b1feb0902ab6"}, + {file = "cffi-1.17.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb1a08b8008b281856e5971307cc386a8e9c5b625ac297e853d36da6efe9c17"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e"}, + {file = "cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be"}, + {file = "cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c"}, + {file = "cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401"}, + {file = "cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a24ed04c8ffd54b0729c07cee15a81d964e6fee0e3d4d342a27b020d22959dc6"}, + {file = "cffi-1.17.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:610faea79c43e44c71e1ec53a554553fa22321b65fae24889706c0a84d4ad86d"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f"}, + {file = "cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b"}, + {file = "cffi-1.17.1-cp311-cp311-win32.whl", hash = 
"sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655"}, + {file = "cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4"}, + {file = "cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f79fc4fc25f1c8698ff97788206bb3c2598949bfe0fef03d299eb1b5356ada99"}, + {file = "cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3"}, + {file = "cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8"}, + {file = "cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65"}, + {file = "cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f3a2b4222ce6b60e2e8b337bb9596923045681d71e5a082783484d845390938e"}, + {file = "cffi-1.17.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0984a4925a435b1da406122d4d7968dd861c1385afe3b45ba82b750f229811e2"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d01b12eeeb4427d3110de311e1774046ad344f5b1a7403101878976ecd7a10f3"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:706510fe141c86a69c8ddc029c7910003a17353970cff3b904ff0686a5927683"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de55b766c7aa2e2a3092c51e0483d700341182f08e67c63630d5b6f200bb28e5"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c59d6e989d07460165cc5ad3c61f9fd8f1b4796eacbd81cee78957842b834af4"}, + {file = "cffi-1.17.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd398dbc6773384a17fe0d3e7eeb8d1a21c2200473ee6806bb5e6a8e62bb73dd"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:3edc8d958eb099c634dace3c7e16560ae474aa3803a5df240542b305d14e14ed"}, + {file = "cffi-1.17.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:72e72408cad3d5419375fc87d289076ee319835bdfa2caad331e377589aebba9"}, + {file = "cffi-1.17.1-cp313-cp313-win32.whl", hash = "sha256:e03eab0a8677fa80d646b5ddece1cbeaf556c313dcfac435ba11f107ba117b5d"}, + {file = "cffi-1.17.1-cp313-cp313-win_amd64.whl", 
hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a"}, + {file = "cffi-1.17.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:636062ea65bd0195bc012fea9321aca499c0504409f413dc88af450b57ffd03b"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c7eac2ef9b63c79431bc4b25f1cd649d7f061a28808cbc6c47b534bd789ef964"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e221cf152cff04059d011ee126477f0d9588303eb57e88923578ace7baad17f9"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:31000ec67d4221a71bd3f67df918b1f88f676f1c3b535a7eb473255fdc0b83fc"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f17be4345073b0a7b8ea599688f692ac3ef23ce28e5df79c04de519dbc4912c"}, + {file = "cffi-1.17.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e2b1fac190ae3ebfe37b979cc1ce69c81f4e4fe5746bb401dca63a9062cdaf1"}, + {file = "cffi-1.17.1-cp38-cp38-win32.whl", hash = "sha256:7596d6620d3fa590f677e9ee430df2958d2d6d6de2feeae5b20e82c00b76fbf8"}, + {file = "cffi-1.17.1-cp38-cp38-win_amd64.whl", hash = "sha256:78122be759c3f8a014ce010908ae03364d00a1f81ab5c7f4a7a5120607ea56e1"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16"}, + {file = "cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:98e3969bcff97cae1b2def8ba499ea3d6f31ddfdb7635374834cf89a1a08ecf0"}, + {file = "cffi-1.17.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cdf5ce3acdfd1661132f2a9c19cac174758dc2352bfe37d98aa7512c6b7178b3"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a"}, + {file = "cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e"}, + {file = "cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7"}, + {file = "cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662"}, + {file = "cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824"}, +] + +[package.dependencies] +pycparser = "*" + [[package]] name = "charset-normalizer" version = "3.4.0" @@ -1059,6 +1210,24 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] 
+[[package]] +name = "minio" +version = "7.2.10" +description = "MinIO Python SDK for Amazon S3 Compatible Cloud Storage" +optional = true +python-versions = ">3.8" +files = [ + {file = "minio-7.2.10-py3-none-any.whl", hash = "sha256:5961c58192b1d70d3a2a362064b8e027b8232688998a6d1251dadbb02ab57a7d"}, + {file = "minio-7.2.10.tar.gz", hash = "sha256:418c31ac79346a580df04a0e14db1becbc548a6e7cca61f9bc4ef3bcd336c449"}, +] + +[package.dependencies] +argon2-cffi = "*" +certifi = "*" +pycryptodome = "*" +typing-extensions = "*" +urllib3 = "*" + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -1159,6 +1328,20 @@ files = [ {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, ] +[[package]] +name = "paho-mqtt" +version = "2.1.0" +description = "MQTT version 5.0/3.1.1 client class" +optional = true +python-versions = ">=3.7" +files = [ + {file = "paho_mqtt-2.1.0-py3-none-any.whl", hash = "sha256:6db9ba9b34ed5bc6b6e3812718c7e06e2fd7444540df2455d2c51bd58808feee"}, + {file = "paho_mqtt-2.1.0.tar.gz", hash = "sha256:12d6e7511d4137555a3f6ea167ae846af2c7357b10bc6fa4f7c3968fc1723834"}, +] + +[package.extras] +proxy = ["pysocks"] + [[package]] name = "pathspec" version = "0.12.1" @@ -1344,6 +1527,58 @@ files = [ {file = "psycopg2-2.9.10.tar.gz", hash = "sha256:12ec0b40b0273f95296233e8750441339298e6a572f7039da5b260e3c8b60e11"}, ] +[[package]] +name = "pycparser" +version = "2.22" +description = "C parser in Python" +optional = true +python-versions = ">=3.8" +files = [ + {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, + {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, +] + +[[package]] +name = "pycryptodome" +version = "3.21.0" +description = "Cryptographic library for Python" +optional = true +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +files = [ + {file = "pycryptodome-3.21.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:dad9bf36eda068e89059d1f07408e397856be9511d7113ea4b586642a429a4fd"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:a1752eca64c60852f38bb29e2c86fca30d7672c024128ef5d70cc15868fa10f4"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:3ba4cc304eac4d4d458f508d4955a88ba25026890e8abff9b60404f76a62c55e"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cb087b8612c8a1a14cf37dd754685be9a8d9869bed2ffaaceb04850a8aeef7e"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-musllinux_1_1_aarch64.whl", hash = "sha256:26412b21df30b2861424a6c6d5b1d8ca8107612a4cfa4d0183e71c5d200fb34a"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-win32.whl", hash = "sha256:cc2269ab4bce40b027b49663d61d816903a4bd90ad88cb99ed561aadb3888dd3"}, + {file = "pycryptodome-3.21.0-cp27-cp27m-win_amd64.whl", hash = "sha256:0fa0a05a6a697ccbf2a12cec3d6d2650b50881899b845fac6e87416f8cb7e87d"}, + {file = "pycryptodome-3.21.0-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6cce52e196a5f1d6797ff7946cdff2038d3b5f0aba4a43cb6bf46b575fd1b5bb"}, + {file = "pycryptodome-3.21.0-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:a915597ffccabe902e7090e199a7bf7a381c5506a747d5e9d27ba55197a2c568"}, + {file = "pycryptodome-3.21.0-cp27-cp27mu-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4e74c522d630766b03a836c15bff77cb657c5fdf098abf8b1ada2aebc7d0819"}, + {file 
= "pycryptodome-3.21.0-cp27-cp27mu-musllinux_1_1_aarch64.whl", hash = "sha256:a3804675283f4764a02db05f5191eb8fec2bb6ca34d466167fc78a5f05bbe6b3"}, + {file = "pycryptodome-3.21.0-cp36-abi3-macosx_10_9_universal2.whl", hash = "sha256:2480ec2c72438430da9f601ebc12c518c093c13111a5c1644c82cdfc2e50b1e4"}, + {file = "pycryptodome-3.21.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:de18954104667f565e2fbb4783b56667f30fb49c4d79b346f52a29cb198d5b6b"}, + {file = "pycryptodome-3.21.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de4b7263a33947ff440412339cb72b28a5a4c769b5c1ca19e33dd6cd1dcec6e"}, + {file = "pycryptodome-3.21.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0714206d467fc911042d01ea3a1847c847bc10884cf674c82e12915cfe1649f8"}, + {file = "pycryptodome-3.21.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d85c1b613121ed3dbaa5a97369b3b757909531a959d229406a75b912dd51dd1"}, + {file = "pycryptodome-3.21.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:8898a66425a57bcf15e25fc19c12490b87bd939800f39a03ea2de2aea5e3611a"}, + {file = "pycryptodome-3.21.0-cp36-abi3-musllinux_1_2_i686.whl", hash = "sha256:932c905b71a56474bff8a9c014030bc3c882cee696b448af920399f730a650c2"}, + {file = "pycryptodome-3.21.0-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:18caa8cfbc676eaaf28613637a89980ad2fd96e00c564135bf90bc3f0b34dd93"}, + {file = "pycryptodome-3.21.0-cp36-abi3-win32.whl", hash = "sha256:280b67d20e33bb63171d55b1067f61fbd932e0b1ad976b3a184303a3dad22764"}, + {file = "pycryptodome-3.21.0-cp36-abi3-win_amd64.whl", hash = "sha256:b7aa25fc0baa5b1d95b7633af4f5f1838467f1815442b22487426f94e0d66c53"}, + {file = "pycryptodome-3.21.0-pp27-pypy_73-manylinux2010_x86_64.whl", hash = "sha256:2cb635b67011bc147c257e61ce864879ffe6d03342dc74b6045059dfbdedafca"}, + {file = "pycryptodome-3.21.0-pp27-pypy_73-win32.whl", hash = "sha256:4c26a2f0dc15f81ea3afa3b0c87b87e501f235d332b7f27e2225ecb80c0b1cdd"}, + {file = "pycryptodome-3.21.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d5ebe0763c982f069d3877832254f64974139f4f9655058452603ff559c482e8"}, + {file = "pycryptodome-3.21.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee86cbde706be13f2dec5a42b52b1c1d1cbb90c8e405c68d0755134735c8dc6"}, + {file = "pycryptodome-3.21.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0fd54003ec3ce4e0f16c484a10bc5d8b9bd77fa662a12b85779a2d2d85d67ee0"}, + {file = "pycryptodome-3.21.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:5dfafca172933506773482b0e18f0cd766fd3920bd03ec85a283df90d8a17bc6"}, + {file = "pycryptodome-3.21.0-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:590ef0898a4b0a15485b05210b4a1c9de8806d3ad3d47f74ab1dc07c67a6827f"}, + {file = "pycryptodome-3.21.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f35e442630bc4bc2e1878482d6f59ea22e280d7121d7adeaedba58c23ab6386b"}, + {file = "pycryptodome-3.21.0-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff99f952db3db2fbe98a0b355175f93ec334ba3d01bbde25ad3a5a33abc02b58"}, + {file = "pycryptodome-3.21.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8acd7d34af70ee63f9a849f957558e49a98f8f1634f86a59d2be62bb8e93f71c"}, + {file = "pycryptodome-3.21.0.tar.gz", hash = "sha256:f7787e0d469bdae763b876174cf2e6c0f7be79808af26b1da96f1a64bcf47297"}, +] + 
[[package]] name = "pydantic" version = "2.9.2" @@ -2433,12 +2668,15 @@ type = ["pytest-mypy"] [extras] all = ["alembic", "anthropic", "fastapi", "groq", "sqlmodel", "uvicorn"] anthropic = ["anthropic"] +api-server = ["fastapi", "uvicorn"] groq = ["groq"] +minio = ["minio"] +mqtt = ["aiomqtt"] postgres = ["alembic", "psycopg2", "sqlmodel"] sqlite = ["alembic", "sqlmodel"] studio = ["alembic", "fastapi", "sqlmodel", "uvicorn"] [metadata] lock-version = "2.0" -python-versions = ">=3.9" -content-hash = "edacdfbe26e995a2e7ea2844c917e2a42c9f3e347b26a19c48405128ba9a187e" +python-versions = ">=3.9,<4.0" +content-hash = "87a490a476850291f82e56c14de0be4c97d9076270030f5e44b4111c581e1651" diff --git a/pyproject.toml b/pyproject.toml index d79e92555..034e5fa53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ include = [ ] [tool.poetry.dependencies] -python = ">=3.9" +python = ">=3.9,<4.0" numpy = ">=1.26.0" dill = "^0.3.8" colorama = "^0.4.6" @@ -41,13 +41,17 @@ psutil = "^5.9.0" anthropic = { version = "^0.34.2", optional = true } groq = { version = "^0.11.0", optional = true } # Storage +alembic = { version = "^1.14.0", optional = true } psycopg2 = { version = ">=2.7", optional = true } sqlmodel = { version = ">=0.0.21, <0.1.0", optional = true } +minio = { version = "7.2.10", optional = true } # Studio fastapi = { version = "^0.111.1", optional = true } uvicorn = { version = "^0.30.3", optional = true } +# Studio + API Server (optional) +aiomqtt = { version="^2.3.0", optional = true } + -alembic = { version = "^1.14.0", optional = true } [tool.poetry.group.dev.dependencies] pytest = "^8.3.2" sphinx = "<8.0.0" @@ -61,6 +65,9 @@ groq = ["groq"] sqlite = [ 'sqlmodel', 'alembic' ] postgres = ['sqlmodel', 'psycopg2', 'alembic'] studio = ['fastapi', 'uvicorn', 'sqlmodel', 'alembic'] +api-server = ["fastapi", "uvicorn"] +mqtt = ["aiomqtt"] +minio = ["minio"] all = [ "anthropic", "groq", diff --git a/src/ell/__version__.py b/src/ell/__version__.py index ffc218495..5fb3a6916 100644 --- a/src/ell/__version__.py +++ b/src/ell/__version__.py @@ -1,6 +1,6 @@ -try: - from importlib.metadata import version -except ImportError: - from importlib_metadata import version +from importlib.metadata import version, PackageNotFoundError -__version__ = version("ell-ai") +try: + __version__ = version("ell-ai") +except PackageNotFoundError: + __version__ = "unknown" \ No newline at end of file diff --git a/src/ell/api/__init__.py b/src/ell/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ell/api/__main__.py b/src/ell/api/__main__.py new file mode 100644 index 000000000..f86ae22ed --- /dev/null +++ b/src/ell/api/__main__.py @@ -0,0 +1,82 @@ +import asyncio +import os +import uvicorn +import logging +from argparse import ArgumentParser + + +from ell.api.config import Config +from ell.api.server import create_app +from ell.api.logger import setup_logging + + + + +def main(): + log_level = int(os.environ.get("LOG_LEVEL", logging.INFO)) + setup_logging(level=log_level) + + parser = ArgumentParser(description="ell api") + parser.add_argument("--storage-dir", + type=str, + default=os.getenv("ELL_STORAGE_DIR"), + help="Storage directory (default: None, env: ELL_STORAGE_DIR)") + parser.add_argument("--pg-connection-string", + default=os.getenv("ELL_PG_CONNECTION_STRING"), + help="PostgreSQL connection string (default: None, env: ELL_PG_CONNECTION_STRING)") + parser.add_argument("--mqtt-connection-string", + default=os.getenv("ELL_MQTT_CONNECTION_STRING"), + help="MQTT connection 
string (default: None, env: ELL_MQTT_CONNECTION_STRING)") + parser.add_argument("--minio-endpoint", + default=os.getenv("ELL_MINIO_ENDPOINT"), + help="MinIO endpoint (default: None, env: ELL_MINIO_ENDPOINT)") + parser.add_argument("--minio-access-key", + default=os.getenv("ELL_MINIO_ACCESS_KEY"), + help="MinIO access key (default: None, env: ELL_MINIO_ACCESS_KEY)") + parser.add_argument("--minio-secret-key", + default=os.getenv("ELL_MINIO_SECRET_KEY"), + help="MinIO secret key (default: None, env: ELL_MINIO_SECRET_KEY)") + parser.add_argument("--minio-bucket", + default=os.getenv("ELL_MINIO_BUCKET"), + help="MinIO bucket (default: None, env: ELL_MINIO_BUCKET)") + parser.add_argument("--host", + default=os.getenv("ELL_API_HOST") or "0.0.0.0", + help="Host to run the server on (default: '0.0.0.0', env: ELL_API_HOST)") + parser.add_argument("--port", + type=int, + default=int(os.getenv("ELL_API_PORT") or 8081), + help="Port to run the server on (default: 8081, env: ELL_API_PORT)") + parser.add_argument("--dev", + action="store_true", + help="Run in development mode") + args = parser.parse_args() + + config = Config( + storage_dir=args.storage_dir, + pg_connection_string=args.pg_connection_string, + mqtt_connection_string=args.mqtt_connection_string, + minio_endpoint=args.minio_endpoint, + minio_access_key=args.minio_access_key, + minio_secret_key=args.minio_secret_key, + minio_bucket=args.minio_bucket, + ) + + app = create_app(config) + + loop = asyncio.new_event_loop() + + config = uvicorn.Config( + app=app, + host=args.host, + port=args.port, + loop=loop # type: ignore + ) + server = uvicorn.Server(config) + + loop.create_task(server.serve()) + + loop.run_forever() + + +if __name__ == "__main__": + main() diff --git a/src/ell/api/config.py b/src/ell/api/config.py new file mode 100644 index 000000000..6d95d8845 --- /dev/null +++ b/src/ell/api/config.py @@ -0,0 +1,38 @@ +# todo. move this under ell.api.server +import json +import os +from typing import Any, Optional +from pydantic import BaseModel + +import logging + +logger = logging.getLogger(__name__) + + +class Config(BaseModel): + storage_dir: Optional[str] = None + pg_connection_string: Optional[str] = None + mqtt_connection_string: Optional[str] = None + minio_endpoint: Optional[str] = None + minio_access_key: Optional[str] = None + minio_secret_key: Optional[str] = None + minio_bucket: Optional[str] = None + log_level: int = logging.INFO + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def model_post_init(self, __context: Any): + # Storage + # Enforce that we use either sqlite or postgres, but not both + if self.pg_connection_string is not None and self.storage_dir is not None: + raise ValueError("Cannot use both sqlite and postgres") + + # For now, fall back to sqlite if no PostgreSQL connection string is provided + if self.pg_connection_string is None and self.storage_dir is None: + # This intends to honor the default we had set in the CLI + # todo. better default? + self.storage_dir = os.getcwd() + + logger.info(f"Resolved config: {json.dumps(self.model_dump(exclude_none=True), indent=2)}") + diff --git a/src/ell/api/logger.py b/src/ell/api/logger.py new file mode 100644 index 000000000..c316b3b2e --- /dev/null +++ b/src/ell/api/logger.py @@ -0,0 +1,42 @@ +#todo. 
move under ell.api.server + +import logging +from colorama import Fore, Style, init + +initialized = False +def setup_logging(level: int = logging.INFO): + global initialized + if initialized: + return + # Initialize colorama for cross-platform colored output + init(autoreset=True) + + # Create a custom formatter + class ColoredFormatter(logging.Formatter): + FORMATS = { + logging.DEBUG: Fore.CYAN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.INFO: Fore.GREEN + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.WARNING: Fore.YELLOW + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.ERROR: Fore.RED + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL, + logging.CRITICAL: Fore.RED + Style.BRIGHT + "[%(asctime)s] %(levelname)-8s %(name)s: %(message)s" + Style.RESET_ALL + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, datefmt="%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + # Create and configure the logger + logger = logging.getLogger("ell") + logger.setLevel(level) + + # Create console handler and set formatter + console_handler = logging.StreamHandler() + console_handler.setFormatter(ColoredFormatter()) + + # Add the handler to the logger + logger.addHandler(console_handler) + + initialized = True + + return logger \ No newline at end of file diff --git a/src/ell/api/pubsub/__init__.py b/src/ell/api/pubsub/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ell/api/pubsub/abc.py b/src/ell/api/pubsub/abc.py new file mode 100644 index 000000000..04551ad73 --- /dev/null +++ b/src/ell/api/pubsub/abc.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +import logging +from typing import List + +from fastapi import WebSocket + +logger = logging.getLogger(__name__) + +Subscriber = WebSocket + + +class PubSub(ABC): + @abstractmethod + async def publish(self, topic: str, message: str) -> None: + pass + + @abstractmethod + def subscribe(self, topic: str, subscriber: Subscriber) -> None: + pass + + @abstractmethod + async def subscribe_async(self, topic: str, subscriber: Subscriber) -> None: + pass + + @abstractmethod + def unsubscribe(self, topic: str, subscriber: Subscriber): + pass + + @abstractmethod + def unsubscribe_from_all(self, subscriber: Subscriber): + pass + + + @abstractmethod + def get_subscriptions(self, subscriber: Subscriber) -> List[str]: + pass diff --git a/src/ell/api/pubsub/mqtt.py b/src/ell/api/pubsub/mqtt.py new file mode 100644 index 000000000..215bf0f3a --- /dev/null +++ b/src/ell/api/pubsub/mqtt.py @@ -0,0 +1,92 @@ +import asyncio +import json +import logging + +import aiomqtt + +from ell.api.pubsub.abc import Subscriber +from ell.api.pubsub.websocket import WebSocketPubSub + +logger = logging.getLogger(__name__) + + +class MqttWebSocketPubSub(WebSocketPubSub): + mqtt_client: aiomqtt.Client + + def __init__(self, conn: aiomqtt.Client): + super().__init__() + self.mqtt_client = conn + + def listen(self, loop: asyncio.AbstractEventLoop): + self.listener = loop.create_task(self._relay_all()) + return self.listener + + async def publish(self, topic: str, message: str) -> None: + # this is a bit sus because we could get in a loop if the message is echoed back + # we're also publishing to mqtt, not websocket clients + await self.mqtt_client.publish(topic, message) + + async def _relay_all(self) -> None: + """ + Relays all messages received on the 
subscribed MQTT topics to the websocket subscribers on the same topics.
+
+        Example:
+        self.subscribe("detailed-telemetry/#") # <- Registers us to receive MQTT messages published to detailed-telemetry/1, detailed-telemetry/2, ...
+
+        Upon receipt, we forward these messages to any connected Ell Studio websockets whose subscription matches the published topic.
+
+        i.e.:
+        Subscriptions map:
+            "detailed-telemetry/1" -> [socket1]
+            "detailed-telemetry/2" -> [socket2]
+            "lmp/#" -> [socket1, socket2]
+        - An MQTT message published to detailed-telemetry/1 will be relayed to socket1
+        - An MQTT message published to lmp/42 will be relayed to socket1 and socket2
+        """
+        logger.info("Starting mqtt listener")
+        async for message in self.mqtt_client.messages:
+            try:
+                logger.debug(f"Received message on topic {message.topic}: {message.payload}")
+                # Call the websocket's publish method to publish the message received from MQTT to the websocket
+                await super().publish(str(message.topic), json.loads(
+                    message.payload  # type: ignore
+                ))
+            except Exception as e:
+                logger.error(f"Error relaying message: {e}")
+
+    async def subscribe_async(self, topic: str, subscriber: Subscriber) -> None:
+        await self.mqtt_client.subscribe(topic)
+        super().subscribe(topic, subscriber)
+
+
+async def setup(
+    mqtt_connection_string: str,
+    retry_interval_seconds: int = 1,
+    retry_max_attempts: int = 5
+) -> tuple[MqttWebSocketPubSub, aiomqtt.Client]:  # type: ignore
+    """
+    Connect to the MQTT broker at `mqtt_connection_string` using the provided retry policy.
+    Returns the client and the open connection, which should be handled by an AsyncExitStack or similar.
+    """
+    for attempt in range(retry_max_attempts):
+        try:
+            host, port = mqtt_connection_string.split("://")[1].split(":")
+            logger.info(f"Connecting to MQTT broker at {host}:{port}")
+
+            # Create the client - it will connect when used as a context manager
+            mqtt_client = aiomqtt.Client(hostname=host, port=int(port) if port else 1883)
+            # We call __aenter__ here in order to connect and retry on failure
+            # The client is passed back and must be closed via __aexit__() (e.g. by an AsyncExitStack)
+            await mqtt_client.__aenter__()
+            return MqttWebSocketPubSub(mqtt_client), mqtt_client
+
+        except aiomqtt.MqttError as e:
+            logger.error(f"Failed to connect to MQTT [Attempt {attempt + 1}/{retry_max_attempts}]: {e}")
+            if attempt < retry_max_attempts - 1:
+                await asyncio.sleep(retry_interval_seconds)
+                continue
+            else:
+                logger.error("Max retry attempts reached. 
Unable to connect to MQTT.") + raise ValueError(f"Failed to connect to MQTT after {retry_max_attempts} attempts") from e diff --git a/src/ell/api/pubsub/topic.py b/src/ell/api/pubsub/topic.py new file mode 100644 index 000000000..e24c578b4 --- /dev/null +++ b/src/ell/api/pubsub/topic.py @@ -0,0 +1,82 @@ +from functools import lru_cache +from typing import Optional + +MAX_TOPIC_LENGTH = 65535 + +class TopicMatcher: + def __init__(self): + # Cache validation and matching results + self._validate_publish_topic = lru_cache(maxsize=1024)(self._validate_publish_topic_impl) + self._validate_subscription_pattern = lru_cache(maxsize=1024)(self._validate_subscription_pattern_impl) + self.matches = lru_cache(maxsize=4096)(self.matches) + + def _validate_publish_topic_impl(self, topic: str) -> tuple[bool, Optional[str]]: + """Internal implementation that returns (is_valid, error_message)""" + if not topic: + return False, "Topic cannot be empty" + if len(topic) > MAX_TOPIC_LENGTH: + return False, f"Topic exceeds maximum length of {MAX_TOPIC_LENGTH}" + if "#" in topic or "+" in topic: + return False, "Publish topics cannot contain wildcards (# or +)" + return True, None + + def _validate_subscription_pattern_impl(self, pattern: str) -> tuple[bool, Optional[str]]: + """Internal implementation that returns (is_valid, error_message)""" + if not pattern: + return False, "Subscription pattern cannot be empty" + if len(pattern) > MAX_TOPIC_LENGTH: + return False, f"Pattern exceeds maximum length of {MAX_TOPIC_LENGTH}" + if "#/" in pattern: + return False, "Multi-level wildcard (#) must be the last character in the pattern" + + for level in pattern.split("/"): + if len(level) > 1: + if "+" in level: + return False, "Single-level wildcard (+) must be alone in its level" + if "#" in level: + return False, "Multi-level wildcard (#) must be alone in its level" + return True, None + + def validate_publish_topic(self, topic: str) -> None: + """Public method that raises ValueError with specific message if invalid""" + is_valid, error = self._validate_publish_topic(topic) + if not is_valid: + raise ValueError(error) + + def validate_subscription_pattern(self, pattern: str) -> None: + """Public method that raises ValueError with specific message if invalid""" + is_valid, error = self._validate_subscription_pattern(pattern) + if not is_valid: + raise ValueError(error) + + def matches(self, topic: str, pattern: str) -> bool: + """Check if a topic matches a wildcard pattern.""" + # Use cached validation methods + self.validate_publish_topic(topic) + self.validate_subscription_pattern(pattern) + + topic_parts = topic.split("/") + pattern_parts = pattern.split("/") + + # Handle shared subscriptions + if pattern_parts[0] == "$share": + pattern_parts = pattern_parts[2:] + + def match_parts(t_parts: list[str], p_parts: list[str]) -> bool: + if not t_parts: + return not p_parts or p_parts[0] == "#" + if not p_parts: + return False + if p_parts[0] == "#": + return True + if p_parts[0] == "+" or t_parts[0] == p_parts[0]: + return match_parts(t_parts[1:], p_parts[1:]) + return False + + return match_parts(topic_parts, pattern_parts) + +matcher = TopicMatcher() +topic_matches = matcher.matches +validate_publish_topic = matcher.validate_publish_topic +validate_subscription_pattern = matcher.validate_subscription_pattern + diff --git a/src/ell/api/pubsub/websocket.py b/src/ell/api/pubsub/websocket.py new file mode 100644 index 000000000..070511d44 --- /dev/null +++ b/src/ell/api/pubsub/websocket.py @@ -0,0 +1,74 @@ +import asyncio 
+from typing import Any, List
+
+from ell.api.pubsub.abc import PubSub, Subscriber, logger
+from ell.api.pubsub.topic import validate_publish_topic, topic_matches, validate_subscription_pattern
+
+
+class WebSocketPubSub(PubSub):
+    def __init__(self):
+        # Topic pattern -> subscribed websockets
+        self.subscriptions: dict[str, list[Subscriber]] = {}
+        # Reverse index for self.subscriptions (websocket -> their subscribed topic patterns)
+        self.subscribers: dict[Subscriber, list[str]] = {}
+
+    async def publish(self, topic: str, message: Any):
+        validate_publish_topic(topic)
+        # Notify all subscribers whose subscription pattern is a match for `topic`
+        subscriptions = self.subscriptions.copy()  # copy to avoid mutating while iterating
+        logger.info(f"Relaying message to socket {topic} subscribers")
+        for pattern in subscriptions:
+            if topic_matches(topic, pattern):
+                for subscriber in subscriptions[pattern]:
+                    asyncio.create_task(subscriber.send_json(
+                        {"topic": topic, "message": message}))
+
+    def subscribe(self, topic_pattern: str, subscriber: Subscriber) -> None:
+        """Subscribes the websocket `subscriber` to receive messages matching the topic pattern `topic_pattern`"""
+        validate_subscription_pattern(topic_pattern)
+        logger.info(f"Subscribing ws {subscriber} to {topic_pattern}")
+        # Add the subscriber to the list for the topic
+        if topic_pattern not in self.subscriptions:
+            self.subscriptions[topic_pattern] = []
+        self.subscriptions[topic_pattern].append(subscriber)
+        if subscriber not in self.subscribers:
+            self.subscribers[subscriber] = []
+        self.subscribers[subscriber].append(topic_pattern)
+
+    def unsubscribe(self, topic: str, subscriber: Subscriber):
+        """Unsubscribes the websocket `subscriber` from the topic pattern `topic`"""
+        subscriptions = self.subscriptions.copy()
+        if topic in subscriptions:
+            # Try to apply the edit to the original subscriptions map
+            try:
+                # Remove the subscriber
+                self.subscriptions[topic].remove(subscriber)
+                # Prune the topic from the subscriptions map if the edit resulted in 0 subscribers for the topic
+                if not self.subscriptions[topic]:
+                    del self.subscriptions[topic]
+            except Exception:
+                # If anything goes wrong in updating the subscriptions map, we assume it's concurrency-related
+                # and the current subscriptions map contains the edit we would have made
+                pass
+
+    def unsubscribe_from_all(self, subscriber: Subscriber):
+        """Removes the websocket `subscriber` from all topics. Typically called on socket disconnect."""
+        # Use .get() so a socket that never subscribed can disconnect without raising KeyError
+        subscriber_subscriptions = self.subscribers.copy().get(subscriber)
+        if subscriber_subscriptions:
+            for topic in subscriber_subscriptions:
+                self.unsubscribe(topic, subscriber)
+        try:
+            del self.subscribers[subscriber]
+        except KeyError:
+            pass
+
+    async def subscribe_async(self, topic_pattern: str, subscriber: Subscriber) -> None:
+        """Subscribes the websocket `subscriber` to receive messages matching the topic pattern `topic_pattern`"""
+        validate_subscription_pattern(topic_pattern)
+        logger.info(f"Subscribing ws {subscriber} to {topic_pattern}")
+        self.subscribe(topic_pattern, subscriber)
+
+    def get_subscriptions(self, subscriber: Subscriber) -> List[str]:
+        """Returns the list of topic patterns that the websocket `subscriber` is subscribed to"""
+        return self.subscribers.get(subscriber, [])
diff --git a/src/ell/api/server.py b/src/ell/api/server.py
new file mode 100644
index 000000000..c5fb54f9e
--- /dev/null
+++ b/src/ell/api/server.py
@@ -0,0 +1,169 @@
+# todo. 
under ell.api.server.___main___ +import asyncio +from contextlib import asynccontextmanager, AsyncExitStack +import logging +from typing import List, Optional + +from fastapi import Depends, FastAPI, HTTPException + +from ell.api.config import Config +from ell.api.pubsub.abc import PubSub +from ell.serialize.serializer import get_async_serializer +from ell.serialize.config import SerializeConfig +from ell.serialize.protocol import EllAsyncSerializer +from ell.types.serialize import GetLMPOutput, LMPInvokedEvent, WriteInvocationInput, WriteLMPInput, LMP, WriteBlobInput +from ell.util.errors import missing_ell_extras + +logger = logging.getLogger(__name__) + +pubsub: Optional[PubSub] = None + + +async def get_pubsub(): + yield pubsub + +async def init_pubsub(config: Config, exit_stack: AsyncExitStack): + """Set up the appropriate pubsub client based on configuration.""" + if config.mqtt_connection_string is not None: + try: + from ell.api.pubsub.mqtt import setup + except ImportError as e: + raise missing_ell_extras( + message="Received mqtt_connection_string but dependencies missing.", + extras=["mqtt"] + ) from e + + pubsub, mqtt_client = await setup(config.mqtt_connection_string) + + exit_stack.push_async_exit(mqtt_client) + + loop = asyncio.get_event_loop() + return pubsub, pubsub.listen(loop) + + return None, None + + + +serializer: Optional[EllAsyncSerializer] = None + + +def init_serializer(config: Config) -> EllAsyncSerializer: + global serializer + if serializer is not None: + return serializer + serializer = get_async_serializer(config=SerializeConfig( + **config.model_dump() + )) + + return serializer + + +def get_serializer(): + if serializer is None: + raise ValueError("Serializer not initialized") + return serializer + + +def create_app(config: Config): + # setup_logging(config.log_level) + + @asynccontextmanager + async def lifespan(app: FastAPI): + global serializer + global pubsub + exit_stack = AsyncExitStack() + pubsub_task = None + + logger.info("Starting lifespan") + + serializer = init_serializer(config) + + try: + pubsub, pubsub_task = await init_pubsub(config, exit_stack) + yield + + finally: + if pubsub_task and not pubsub_task.done(): + pubsub_task.cancel() + try: + await pubsub_task + except asyncio.CancelledError: + pass + + await exit_stack.aclose() + pubsub = None + + app = FastAPI( + title="ell api", + description="ell api server", + version="0.1.0", + lifespan=lifespan + ) + + @app.get("/lmp/versions", response_model=List[LMP]) + async def get_lmp_versions( + fqn: str, + serializer: EllAsyncSerializer = Depends(get_serializer)): + return await serializer.get_lmp_versions(fqn) + + @app.get("/lmp/{lmp_id}", response_model=GetLMPOutput) + async def get_lmp(lmp_id: str, serializer: EllAsyncSerializer = Depends(get_serializer)): + lmp = await serializer.get_lmp(lmp_id=lmp_id) + if lmp is None: + raise HTTPException(status_code=404, detail="LMP not found") + return lmp + + @app.post("/lmp") + async def write_lmp( + lmp: WriteLMPInput, + pubsub: PubSub = Depends(get_pubsub), + serializer: EllAsyncSerializer = Depends(get_serializer) + ): + await serializer.write_lmp(lmp) + + if pubsub: + loop = asyncio.get_event_loop() + loop.create_task( + pubsub.publish( + f"lmp/{lmp.lmp_id}/created", + lmp.model_dump_json(exclude_none=True, exclude_unset=True), + ) + ) + + @app.post("/invocation", response_model=WriteInvocationInput) + async def write_invocation( + input: WriteInvocationInput, + pubsub: PubSub = Depends(get_pubsub), + serializer: EllAsyncSerializer = 
Depends(get_serializer) + ): + logger.info(f"Writing invocation {input.invocation.lmp_id}") + # TODO: return anything this might create like invocation id + result = await serializer.write_invocation(input) + + if pubsub: + loop = asyncio.get_event_loop() + loop.create_task( + pubsub.publish( + f"lmp/{input.invocation.lmp_id}/invoked", + LMPInvokedEvent( + lmp_id=input.invocation.lmp_id, + # invocation_id=invo.id, + # todo. return data from write invocation + consumes=[] + ).model_dump_json() + ) + ) + + return input + + @app.post("/blob") + async def store_blob( + input: WriteBlobInput, + serializer: EllAsyncSerializer = Depends(get_serializer) + ): + if not serializer.supports_blobs: + raise HTTPException(status_code=400, detail="Blob support is not enabled.") + return await serializer.store_blob(**input.model_dump()) + + + return app diff --git a/src/ell/configurator.py b/src/ell/configurator.py index 0537d03b2..38908d98b 100644 --- a/src/ell/configurator.py +++ b/src/ell/configurator.py @@ -8,8 +8,13 @@ from ell.provider import Provider from dataclasses import dataclass, field +from ell.serialize.config import SerializeConfig +from ell.serialize.protocol import EllSerializer +from ell.serialize.serializer import get_serializer +from ell.util.errors import missing_ell_extras + if TYPE_CHECKING: - from ell.stores import Store + from ell.stores.store import Store else: Store = None @@ -34,56 +39,60 @@ class _Model: class Config(BaseModel): """Configuration class for ELL.""" - + model_config = ConfigDict( arbitrary_types_allowed=True, protected_namespaces=('protect_', ) # Override protected namespaces ) registry: Dict[str, _Model] = Field( - default_factory=dict, + default_factory=dict, description="A dictionary mapping model names to their configurations." ) verbose: bool = Field( - default=False, + default=False, description="If True, enables verbose logging." ) wrapped_logging: bool = Field( - default=True, + default=True, description="If True, enables wrapped logging for better readability." ) override_wrapped_logging_width: Optional[int] = Field( - default=None, + default=None, description="If set, overrides the default width for wrapped logging." ) store: Optional[Store] = Field( - default=None, + default=None, description="An optional Store instance for persistence." ) autocommit: bool = Field( - default=False, + default=False, description="If True, enables automatic committing of changes to the store." ) lazy_versioning: bool = Field( - default=True, + default=True, description="If True, enables lazy versioning for improved performance." ) default_api_params: Dict[str, Any] = Field( - default_factory=dict, + default_factory=dict, description="Default parameters for language models." ) default_client: Optional[openai.Client] = Field( - default=None, + default=None, description="The default OpenAI client used when a specific model client is not found." ) autocommit_model: str = Field( - default="gpt-4o-mini", + default="gpt-4o-mini", description="When set, changes the default autocommit model from GPT 4o mini." ) providers: Dict[Type, Provider] = Field( - default_factory=dict, + default_factory=dict, description="A dictionary mapping client types to provider classes." 
     )
+    serializer: Optional[EllSerializer] = Field(
+        default=None,
+        description="Serializer used for LMPs and invocations"
+    )

     def __init__(self, **data):
         super().__init__(**data)
         self._lock = threading.Lock()
@@ -195,7 +204,9 @@ def init(
     lazy_versioning: bool = True,
     default_api_params: Optional[Dict[str, Any]] = None,
     default_client: Optional[Any] = None,
-    autocommit_model: str = "gpt-4o-mini"
+    autocommit_model: str = "gpt-4o-mini",
+    api_url: Optional[str] = None,
+    serializer: Optional[EllSerializer] = None,
 ) -> None:
     """
     Initialize the ELL configuration with various settings.
@@ -214,19 +225,38 @@ def init(
     :type default_openai_client: openai.Client, optional
     :param autocommit_model: Set the model used for autocommitting.
     :type autocommit_model: str
+    :param api_url: ell API server URL.
+    :type api_url: str, optional
+    :param serializer: Serializer instance used for LMPs and invocations.
+    :type serializer: EllSerializer, optional
     """
     # XXX: prevent double init
     config.verbose = verbose
     config.lazy_versioning = lazy_versioning

-    if isinstance(store, str):
+    if store and not isinstance(store, str):
         try:
-            from ell.stores.sql import SQLiteStore
-            config.store = SQLiteStore(store)
+            from ell.serialize.sql import SQLSerializer
+            config.serializer = SQLSerializer(store)
+            config.store = config.serializer.store # legacy
         except ImportError:
-            raise ImportError("Failed importing SQLiteStore. Install with `pip install -U ell-ai[all]`. More info: https://docs.ell.so/installation")
+            raise missing_ell_extras(
+                message="Failed importing SQL store dependencies",
+                extras=["all"]
+            )
     else:
-        config.store = store
+        if serializer is not None:
+            config.serializer = serializer
+        else:
+            serialize_config = SerializeConfig(
+                api_url=api_url,
+                storage_dir=store,
+                # ...other options
+                log_level=20 if verbose else 0,
+            )
+            if serialize_config.is_enabled:
+                config.serializer = get_serializer(serialize_config)
+
     config.autocommit = autocommit or config.autocommit

     if default_api_params is not None:
diff --git a/src/ell/lmp/_track.py b/src/ell/lmp/_track.py
index 0799adcc8..0eb5cd0b0 100644
--- a/src/ell/lmp/_track.py
+++ b/src/ell/lmp/_track.py
@@ -1,7 +1,10 @@
 import json
 import logging
 import threading
+
+from ell.types import Message, ContentBlock, ToolResult
 from ell.types.lmp import LMPType
+from ell.types.serialize import Invocation, InvocationContents, WriteInvocationInput, utc_now, WriteLMPInput
 from ell.util._warnings import _autocommit_warning
 import ell.util.closure
 from ell.configurator import config
@@ -16,10 +19,6 @@
 from ell.util.serialization import compute_state_cache_key
 from ell.util.serialization import prepare_invocation_params

-try:
-    from ell.stores.models.core import SerializedLMP, Invocation, InvocationContents
-except ImportError:
-    SerializedLMP = Invocation = InvocationContents = None

 logger = logging.getLogger(__name__)

@@ -57,7 +56,7 @@ def _track(
     func_to_track.__ell_force_closure__ = lambda: ell.util.closure.lexically_closured_source(func_to_track, forced_dependencies)
     if not hasattr(func_to_track, "__ell_hash__") and not config.lazy_versioning:
         func_to_track.__ell_force_closure__()
-        
+
     @wraps(func_to_track)
     def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:
@@ -66,7 +65,7 @@ def tracked_func(*fn_args, _get_invocation_id=False, **fn_kwargs) -> str:
         invocation_id = "invocation-" + secrets.token_hex(16)
         state_cache_key: str = None
-        if not config.store:
+        if not config.serializer:
             res = func_to_track(
                 *fn_args, **fn_kwargs, _invocation_origin=invocation_id
             )[0]
@@ -216,7 +215,7 @@ def
serialize_lmp(func): # print(name) api_params = getattr(func, "__ell_api_params__", None) - lmps = config.store.get_versions_by_fqn(fqn=name) + lmps = config.serializer.get_lmp_versions(fqn=name) version = 0 already_in_store = any(lmp.lmp_id == func.__ell_hash__ for lmp in lmps) @@ -237,7 +236,7 @@ def serialize_lmp(func): )[0] ) - serialized_lmp = SerializedLMP( + serialized_lmp = WriteLMPInput( lmp_id=func.__ell_hash__, name=name, created_at=utc_now(), @@ -249,46 +248,55 @@ def serialize_lmp(func): lmp_type=lmp_type, api_params=api_params if api_params else None, version_number=version, + uses=[f.__ell_hash__ for f in func.__ell_uses__], ) - config.store.write_lmp( - serialized_lmp, [f.__ell_hash__ for f in func.__ell_uses__] - ) + config.serializer.write_lmp(serialized_lmp) func._has_serialized_lmp = True return func - def _write_invocation( - func, - invocation_id, - latency_ms, - prompt_tokens, - completion_tokens, - state_cache_key, - invocation_api_params, - cleaned_invocation_params, - consumes, - result, - parent_invocation_id, + func, + invocation_id, + latency_ms, + prompt_tokens, + completion_tokens, + state_cache_key, + invocation_api_params, + cleaned_invocation_params, + consumes, + result, + parent_invocation_id ): + # print(result) + # todo(alex). figure out what's going on here, looks like we're getting result as a tool result / single message sometimes + + results = None + if isinstance(result, list): + results = result + elif isinstance(result, ToolResult): + results = [Message(role='tool', content=[ContentBlock(tool_result=result)])] + else: + results = [result] + invocation_contents = InvocationContents( invocation_id=invocation_id, params=cleaned_invocation_params, - results=result, + results=results, invocation_api_params=invocation_api_params, global_vars=get_immutable_vars(func.__ell_closure__[2]), free_vars=get_immutable_vars(func.__ell_closure__[3]), ) - if invocation_contents.should_externalize and config.store.has_blob_storage: + if invocation_contents.should_externalize and config.serializer.supports_blobs: invocation_contents.is_external = True - - # Write to the blob store - blob_id = config.store.blob_store.store_blob( - json.dumps( - invocation_contents.model_dump(), default=str, ensure_ascii=False - ).encode("utf-8"), - invocation_id, + + # Write to the blob store + blob_id = config.serializer.store_blob( + blob_id=invocation_id, + #todo(alex): normalize serialization + blob=json.dumps(invocation_contents.model_dump( + ), default=str, ensure_ascii=False).encode('utf-8'), ) invocation_contents = InvocationContents( invocation_id=invocation_id, @@ -307,4 +315,4 @@ def _write_invocation( contents=invocation_contents, ) - config.store.write_invocation(invocation, consumes) + config.serializer.write_invocation(WriteInvocationInput(invocation=invocation, consumes=consumes)) diff --git a/src/ell/lmp/tool.py b/src/ell/lmp/tool.py index fd6c074f9..4abc32685 100644 --- a/src/ell/lmp/tool.py +++ b/src/ell/lmp/tool.py @@ -97,7 +97,7 @@ def wrapper( # Determine the type annotation if param.annotation == inspect.Parameter.empty: - raise ValueError(f"Parameter {param_name} has no type annotation, and cannot be converted into a tool schema for OpenAI and other provisders. Should OpenAI produce a string or an integer, etc, for this parameter?") + raise ValueError(f"Parameter {param_name} has no type annotation, and cannot be converted into a tool schema for OpenAI and other providers. 
Should OpenAI produce a string or an integer, etc, for this parameter?") annotation = param.annotation # Determine the default value diff --git a/src/ell/providers/anthropic.py b/src/ell/providers/anthropic.py index 2002f0ea8..ba379b09b 100644 --- a/src/ell/providers/anthropic.py +++ b/src/ell/providers/anthropic.py @@ -1,4 +1,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, cast + +from pydantic import BaseModel + from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall, ImageContent @@ -194,7 +197,7 @@ def _content_block_to_anthropic_format(content_block: ContentBlock): type="tool_use", id=tool_call.tool_call_id, name=tool_call.tool.__name__, - input=tool_call.params.model_dump() + input=tool_call.serialize_params(), ) elif (tool_result := content_block.tool_result): return dict( diff --git a/src/ell/providers/bedrock.py b/src/ell/providers/bedrock.py index cb3674556..6178a5d23 100644 --- a/src/ell/providers/bedrock.py +++ b/src/ell/providers/bedrock.py @@ -1,6 +1,9 @@ from abc import ABC, abstractmethod from collections import defaultdict from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast + +from pydantic import BaseModel + from ell.provider import EllCallParams, Metadata, Provider from ell.types import Message, ContentBlock, ToolCall, ImageContent from ell.types._lstr import _lstr @@ -200,7 +203,7 @@ def content_block_to_bedrock_format(content_block: ContentBlock) -> Dict[str, An "toolUse": { "toolUseId": content_block.tool_call.tool_call_id, "name": content_block.tool_call.tool.__name__, - "input": content_block.tool_call.params.model_dump() + "input": content_block.tool_call.serialize_params(), } } elif content_block.tool_result: diff --git a/src/ell/providers/openai.py b/src/ell/providers/openai.py index 8fe7f4d1b..ef85ff416 100644 --- a/src/ell/providers/openai.py +++ b/src/ell/providers/openai.py @@ -64,7 +64,7 @@ def translate_to_provider(self, ell_call : EllCallParams) -> Dict[str, Any]: type="function", function=dict( name=tool_call.tool.__name__, - arguments=json.dumps(tool_call.params.model_dump(), ensure_ascii=False) + arguments=json.dumps(tool_call.serialize_params(), ensure_ascii=False) ) ) for tool_call in tool_calls ], role="assistant", diff --git a/src/ell/serialize/__init__.py b/src/ell/serialize/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ell/serialize/config.py b/src/ell/serialize/config.py new file mode 100644 index 000000000..0d0a9c504 --- /dev/null +++ b/src/ell/serialize/config.py @@ -0,0 +1,33 @@ +import json +from typing import Any, Optional +from pydantic import BaseModel, Field, computed_field + +import logging + +logger = logging.getLogger(__name__) + + +class SerializeConfig(BaseModel): + storage_dir: Optional[str] = Field(default=None, description="Filesystem path used for SQLite and local blob storage") + api_url: Optional[str] = Field(default=None, description="ell API server endpoint") + pg_connection_string: Optional[str] = None + minio_endpoint: Optional[str] = None + minio_access_key: Optional[str] = None + minio_secret_key: Optional[str] = None + minio_bucket: Optional[str] = None + log_level: int = logging.INFO + + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def model_post_init(self, __context: Any): + # Enforce that we use 1 storage backend (for now) + if self.pg_connection_string is not None and self.storage_dir is not None: + raise ValueError("Cannot use both 
sqlite and postgres") + logger.debug(f"Resolved config: {json.dumps(self.model_dump(exclude_none=True), indent=2)}") + + @computed_field + def is_enabled(self) -> bool: + return bool(self.api_url or self.pg_connection_string or self.storage_dir or self.minio_endpoint) + diff --git a/src/ell/serialize/http.py b/src/ell/serialize/http.py new file mode 100644 index 000000000..4744f826a --- /dev/null +++ b/src/ell/serialize/http.py @@ -0,0 +1,260 @@ +import logging +from typing import List, Optional, Dict, Any + +import httpx +from httpx import HTTPStatusError + +from ell.serialize.protocol import EllAsyncSerializer, EllSerializer +from ell.types.serialize import GetLMPOutput, WriteLMPInput, LMP, WriteInvocationInput + + +# tood. make sure we don't lose any information or end up with malformed stuff relative to what +# the sto4res have been doing for serialization (this function) +# this should probably just be handled by the serialization types to centralize serialization code in one place +# def to_json(obj): +# """Serializes ell objects to json for writing to the database or wire protocols""" +# return json.dumps( +# pydantic_ltype_aware_cattr.unstructure(obj), +# sort_keys=True, default=repr, ensure_ascii=False) + +def make_handle_http_error(logger: logging.Logger): + def handle_http_error( + error: HTTPStatusError, + span: str, + message: Optional[str] = None, + extra: Optional[Dict[str, Any]] = None + ) -> None: + if error.response.status_code == 422: + error_detail = error.response.json().get( + "detail", "No detailed error message provided") + logger.error( + message or f"HTTP {error.response.status_code} Error in {span}", + extra={ + **(extra or {}), + "status_code": error.response.status_code, + "error_detail": error_detail, + "span": span, + "url": str(error.response.url), + "request_id": error.response.headers.get("x-request-id"), + } + ) + raise ValueError(f"Invalid input: {error_detail}") from error + raise + + return handle_http_error + + +class EllHTTPSerializer(EllSerializer): + def __init__(self, base_url: Optional[str] = None, client: Optional[httpx.Client] = None): + assert base_url is not None or client is not None, "Either base_url or client must be provided" + self.base_url = base_url + self.client = client or httpx.Client(base_url=base_url) # type: ignore + self.supports_blobs = True # we assume the server does, if not will find out later + self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__) + self._handle_http_error = make_handle_http_error(self.logger) + + def get_lmp(self, lmp_id: str) -> GetLMPOutput: + try: + response = self.client.get(f"/lmp/{lmp_id}") + response.raise_for_status() + data = response.json() + return None if data is None else LMP(**data) + except HTTPStatusError as e: + self._handle_http_error(error=e, span="get_lmp", message="Failed to get LMP", extra={lmp_id: lmp_id}) + raise + + def write_lmp(self, lmp: WriteLMPInput) -> None: + try: + response = self.client.post("/lmp", + headers={"Content-Type": "application/json"}, + content=lmp.model_dump_json(exclude_none=True, exclude_unset=True)) + response.raise_for_status() + except HTTPStatusError as e: + self._handle_http_error( + message="Failed to write LMP", + span="write_lmp", + error=e, + extra={'lmp_id': lmp.lmp_id, 'lmp_version': lmp.version_number} + ) + raise + + def write_invocation(self, input: WriteInvocationInput) -> None: + try: + response = self.client.post( + url="/invocation", + headers={"Content-Type": "application/json"}, + 
content=input.model_dump_json(exclude_none=True, exclude_unset=True), + ) + response.raise_for_status() + return None + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="write_invocation", + message="Failed to write invocation", + extra={'invocation_id': input.invocation.id} + ) + raise + + def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + try: + response = self.client.post("/blob", data={ + "blob_id": blob_id, + "blob": blob, + "metadata": metadata + }) + response.raise_for_status() + return response.json()["blob_id"] + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="store_blob", + message="Failed to store blob", + extra={'blob_id': blob_id} + ) + raise + + def retrieve_blob(self, blob_id: str) -> bytes: + try: + response = self.client.get(f"/blob/{blob_id}") + response.raise_for_status() + return response.content + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="retrieve_blob", + message="Failed to retrieve blob", + extra={'blob_id': blob_id} + ) + raise + + def close(self): + self.client.close() + + def get_lmp_versions(self, fqn: str) -> List[LMP]: + try: + response = self.client.get("/lmp/versions", params={"fqn": fqn}) + response.raise_for_status() + data = response.json() + return [LMP(**lmp_data) for lmp_data in data] + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="get_lmp_versions", + message="Failed to get LMP versions", + extra={'fqn': fqn} + ) + raise + + +class EllAsyncHTTPSerializer(EllAsyncSerializer): + def __init__(self, base_url: str): + self.base_url = base_url + self.client = httpx.AsyncClient(base_url=base_url) + self.supports_blobs = True # we assume the server does, if not will find out later + self.logger = logging.getLogger( + __name__).getChild(self.__class__.__name__) + self._handle_http_error = make_handle_http_error(self.logger) + + async def get_lmp(self, lmp_id: str) -> GetLMPOutput: + try: + response = await self.client.get(f"/lmp/{lmp_id}") + response.raise_for_status() + data = response.json() + if data is None: + return None + return LMP(**data) + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="get_lmp", + message="Failed to get LMP", + extra={'lmp_id': lmp_id} + ) + raise + + async def get_lmp_versions(self, fqn: str) -> List[LMP]: + try: + response = await self.client.get("/lmp/versions", params={"fqn": fqn}) + response.raise_for_status() + data = response.json() + return [LMP(**lmp_data) for lmp_data in data] + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="get_lmp_versions", + message="Failed to get LMP versions", + extra={'fqn': fqn} + ) + raise + + async def write_lmp(self, lmp: WriteLMPInput, uses: List[str]) -> None: + try: + response = await self.client.post("/lmp", json={ + "lmp": lmp.model_dump(mode="json", exclude_none=True, exclude_unset=True), + "uses": uses + }) + response.raise_for_status() + except HTTPStatusError as e: + self._handle_http_error( + error=e, + span="write_lmp", + message="Failed to write LMP", + extra={'lmp_id': lmp.lmp_id, 'lmp_version': lmp.version_number} + ) + raise + + async def write_invocation(self, input: WriteInvocationInput) -> None: + try: + response = await self.client.post( + "/invocation", + headers={"Content-Type": "application/json"}, + content=input.model_dump_json(exclude_none=True, exclude_unset=True) + ) + response.raise_for_status() + return None + except HTTPStatusError as e: + 
self._handle_http_error(message="Failed to write invocation", span="write_invocation", error=e,
+                                    extra={'invocation_id': input.invocation.id})
+            raise
+
+    async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
+        try:
+            response = await self.client.post("/blob", data={
+                "blob_id": blob_id,
+                "blob": blob,
+                "metadata": metadata
+            })
+            response.raise_for_status()
+            return response.json()["blob_id"]
+        except HTTPStatusError as e:
+            self._handle_http_error(
+                error=e,
+                span="store_blob",
+                message="Failed to store blob",
+                extra={'blob_id': blob_id}
+            )
+            raise
+
+    async def retrieve_blob(self, blob_id: str) -> bytes:
+        try:
+            response = await self.client.get(f"/blob/{blob_id}")
+            response.raise_for_status()
+            return response.content
+        except HTTPStatusError as e:
+            self._handle_http_error(
+                error=e,
+                span="retrieve_blob",
+                message="Failed to retrieve blob",
+                extra={'blob_id': blob_id}
+            )
+            raise
+
+    async def close(self):
+        await self.client.aclose()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.close()
diff --git a/src/ell/serialize/postgres.py b/src/ell/serialize/postgres.py
new file mode 100644
index 000000000..82a0da83a
--- /dev/null
+++ b/src/ell/serialize/postgres.py
@@ -0,0 +1,16 @@
+from typing import Optional
+
+from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer
+from ell.stores.sql import PostgresStore
+from ell.stores.store import BlobStore, AsyncBlobStore
+
+
+class PostgresSerializer(SQLSerializer):
+    def __init__(self, db_uri: str, blob_store: Optional[BlobStore] = None):
+        super().__init__(PostgresStore(db_uri, blob_store))
+
+
+# todo(async): the underlying store is not async-aware
+class AsyncPostgresSerializer(AsyncSQLSerializer):
+    def __init__(self, db_uri: str, blob_store: Optional[AsyncBlobStore] = None):
+        super().__init__(PostgresStore(db_uri, blob_store))
diff --git a/src/ell/serialize/protocol.py b/src/ell/serialize/protocol.py
new file mode 100644
index 000000000..2da1a3a44
--- /dev/null
+++ b/src/ell/serialize/protocol.py
@@ -0,0 +1,55 @@
+from typing import Protocol, Optional, List, Dict, Any, runtime_checkable
+
+from ell.types.serialize import GetLMPOutput, WriteLMPInput, WriteInvocationInput, LMP
+
+
+@runtime_checkable
+class EllSerializer(Protocol):
+    supports_blobs: bool
+
+    def get_lmp(self, lmp_id: str) -> GetLMPOutput:
+        ...
+
+    def write_lmp(self, lmp: WriteLMPInput) -> None:
+        ...
+
+    def write_invocation(self, input: WriteInvocationInput) -> None:
+        ...
+
+    def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
+        ...
+
+    def retrieve_blob(self, blob_id: str) -> bytes:
+        ...
+
+    def close(self):
+        ...
+
+    def get_lmp_versions(self, fqn: str) -> List[LMP]:
+        ...
+
+
+@runtime_checkable
+class EllAsyncSerializer(Protocol):
+    supports_blobs: bool
+
+    async def get_lmp(self, lmp_id: str) -> GetLMPOutput:
+        ...
+
+    async def write_lmp(self, lmp: WriteLMPInput) -> None:
+        ...
+
+    async def write_invocation(self, input: WriteInvocationInput) -> None:
+        ...
+
+    async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str:
+        ...
+
+    async def retrieve_blob(self, blob_id: str) -> bytes:
+        ...
+
+    async def close(self):
+        ...
+
+    async def get_lmp_versions(self, fqn: str) -> List[LMP]:
+        ...
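Note: the following usage sketch is not part of the patch. It illustrates, assuming the matching storage extras are installed, how a concrete EllSerializer is resolved from a SerializeConfig by get_serializer (defined in src/ell/serialize/serializer.py, the next file below). The storage path "./ell_store" and the LMP name "my_module.my_lmp" are made-up example values.

from ell.serialize.config import SerializeConfig
from ell.serialize.serializer import get_serializer  # defined in the next file of this diff

# Exactly one backend is expected: storage_dir selects SQLite, pg_connection_string
# selects Postgres, and api_url selects the HTTP client that talks to the API server.
config = SerializeConfig(storage_dir="./ell_store")  # illustrative path only
serializer = get_serializer(config)

# The resulting object satisfies the EllSerializer protocol above.
versions = serializer.get_lmp_versions(fqn="my_module.my_lmp")  # hypothetical LMP name
print(serializer.supports_blobs, len(versions))

Any object satisfying the protocol can also be passed directly to ell.init(serializer=...), which is what the configurator change earlier in this diff relies on.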
diff --git a/src/ell/serialize/serializer.py b/src/ell/serialize/serializer.py new file mode 100644 index 000000000..6efdf718d --- /dev/null +++ b/src/ell/serialize/serializer.py @@ -0,0 +1,80 @@ +from typing import Optional + +from ell.serialize.protocol import EllSerializer, EllAsyncSerializer +from ell.stores.store import AsyncBlobStore, BlobStore +from ell.serialize.config import SerializeConfig +from ell.util.errors import missing_ell_extras + + + +def get_blob_store(config: SerializeConfig) -> Optional[BlobStore]: + if config.minio_endpoint is not None: + try: + from ell.stores.minio import MinioBlobStore, MinioConfig + minio_config = MinioConfig( + endpoint=config.minio_endpoint, + access_key=config.minio_access_key, # type: ignore + secret_key=config.minio_secret_key, # type: ignore + bucket=config.minio_bucket # type: ignore + ) + return MinioBlobStore(minio_config) + except ImportError: + raise missing_ell_extras(message="MinIO storage is not enabled.", extras=["minio"]) + return None + + +def get_serializer(config: SerializeConfig) -> EllSerializer: + blob_store = get_blob_store(config) + if config.pg_connection_string: + try: + from ell.serialize.postgres import PostgresSerializer + return PostgresSerializer(config.pg_connection_string, blob_store) # type: ignore + except ImportError: + raise missing_ell_extras(message="Postgres storage is not enabled.", extras=["postgres"]) + if config.storage_dir: + try: + from ell.serialize.sqlite import SQLiteSerializer + return SQLiteSerializer(config.storage_dir, blob_store) + except ImportError: + raise missing_ell_extras(message="SQLite storage is not enabled.", extras=["sqlite"]) + if config.api_url: + try: + from ell.serialize.http import EllHTTPSerializer + return EllHTTPSerializer(config.api_url) + except ImportError: + raise missing_ell_extras(message="HTTP serialization is not enabled.", extras=["sqlite"]) + + raise ValueError("No storage configuration found.") + + +def get_async_blob_store(config: SerializeConfig) -> Optional[AsyncBlobStore]: + if config.minio_endpoint is not None: + try: + from ell.stores.minio import AsyncMinioBlobStore, MinioConfig + minio_config = MinioConfig( + endpoint=config.minio_endpoint, + access_key=config.minio_access_key, # type: ignore + secret_key=config.minio_secret_key, # type: ignore + bucket=config.minio_bucket # type: ignore + ) + return AsyncMinioBlobStore(minio_config) + except ImportError: + raise missing_ell_extras(message="MinIO storage is not enabled.", extras=["minio"]) + return None + + +def get_async_serializer(config: SerializeConfig) -> EllAsyncSerializer: + blob_store = get_async_blob_store(config) + if config.pg_connection_string: + try: + from ell.serialize.postgres import AsyncPostgresSerializer + return AsyncPostgresSerializer(config.pg_connection_string, blob_store) + except ImportError: + raise missing_ell_extras(message="Postgres storage is not enabled.", extras=["postgres"]) + if config.storage_dir: + try: + from ell.serialize.sqlite import AsyncSQLiteSerializer + return AsyncSQLiteSerializer(config.storage_dir, blob_store) + except ImportError: + raise missing_ell_extras(message="SQLite storage is not enabled.", extras=["sqlite"]) + raise ValueError("No storage configuration found.") diff --git a/src/ell/serialize/sql.py b/src/ell/serialize/sql.py new file mode 100644 index 000000000..0952fd5da --- /dev/null +++ b/src/ell/serialize/sql.py @@ -0,0 +1,92 @@ +from typing import List, Optional, Dict, Any + +from ell.stores.store import Store +from ell.stores.models import 
Invocation, SerializedLMP +from ell.types.serialize import LMP, WriteLMPInput, WriteInvocationInput +from ell.serialize.protocol import EllSerializer, EllAsyncSerializer + + +class SQLSerializer(EllSerializer): + def __init__(self, store: Store): + self.store = store + self.supports_blobs = store.has_blob_storage + + def get_lmp(self, lmp_id: str): + lmp = self.store.get_lmp(lmp_id) + if lmp: + return LMP(**lmp.model_dump()) + return None + + def get_lmp_versions(self, fqn: str) -> List[LMP]: + slmps = self.store.get_versions_by_fqn(fqn) + return [LMP(**slmp.model_dump()) for slmp in slmps] + + def write_lmp(self, lmp: WriteLMPInput) -> None: + model = SerializedLMP.coerce(lmp) + self.store.write_lmp(model, lmp.uses) + + def write_invocation(self, input: WriteInvocationInput) -> None: + invocation = Invocation.coerce(input.invocation) + self.store.write_invocation(invocation, set(input.consumes)) + return None + + def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.store_blob(blob=blob, blob_id=blob_id) + + def retrieve_blob(self, blob_id: str) -> bytes: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.retrieve_blob(blob_id) + + def close(self): + pass + + +# todo(async): the underlying store and blob store is not async-aware +class AsyncSQLSerializer(EllAsyncSerializer): + def __init__(self, store: Store): + self.store = store + self.supports_blobs = store.has_blob_storage + + async def get_lmp(self, lmp_id: str) -> Optional[LMP]: + lmp = self.store.get_lmp(lmp_id) + if lmp: + return LMP(**lmp.model_dump()) + return None + + async def get_lmp_versions(self, fqn: str) -> List[LMP]: + slmps = self.store.get_versions_by_fqn(fqn) + return [LMP(**slmp.model_dump()) for slmp in slmps] + + async def write_lmp(self, lmp: WriteLMPInput) -> None: + model = SerializedLMP.coerce(lmp) + self.store.write_lmp(model, lmp.uses) + + async def write_invocation(self, input: WriteInvocationInput) -> None: + invocation = Invocation.coerce(input.invocation) + self.store.write_invocation( + invocation, + set(input.consumes) + ) + return None + + async def store_blob(self, blob_id: str, blob: bytes, metadata: Optional[Dict[str, Any]] = None) -> str: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.store_blob(blob=blob, blob_id=blob_id) + + async def retrieve_blob(self, blob_id: str) -> bytes: + if self.store.blob_store is None: + raise ValueError("Blob store is not enabled") + return self.store.blob_store.retrieve_blob(blob_id) + + async def close(self): + pass + + async def __aenter__(self): + return self + + async def __aexit__(self): + await self.close() diff --git a/src/ell/serialize/sqlite.py b/src/ell/serialize/sqlite.py new file mode 100644 index 000000000..49ea618fb --- /dev/null +++ b/src/ell/serialize/sqlite.py @@ -0,0 +1,18 @@ +from typing import Optional + +from ell.serialize.sql import SQLSerializer, AsyncSQLSerializer +from ell.stores.sql import SQLiteStore +from ell.stores.store import AsyncBlobStore, BlobStore + + + +class SQLiteSerializer(SQLSerializer): + def __init__(self, storage_dir: str, blob_store: Optional[BlobStore] = None): + super().__init__(SQLiteStore(storage_dir, blob_store)) + + +# todo(async). 
underlying store is not async +class AsyncSQLiteSerializer(AsyncSQLSerializer): + def __init__(self, storage_dir: str, blob_store: Optional[AsyncBlobStore] = None): + super().__init__(SQLiteStore(storage_dir, blob_store)) + diff --git a/src/ell/stores/__init__.py b/src/ell/stores/__init__.py index 45ec0bd33..9ac77a283 100644 --- a/src/ell/stores/__init__.py +++ b/src/ell/stores/__init__.py @@ -1,4 +1,6 @@ try: + # TODO. this will actually be ok once we have stores that do not require sqlmodel, so we may not want to rely on it now + # or have a stores.sql module later import sqlmodel except ImportError: raise ImportError("ell.stores has missing dependencies. Install them with `pip install -U ell-ai[sqlite]` or `pip install -U ell-ai[postgres]`. More info: https://docs.ell.so/installation/custom-installation") diff --git a/src/ell/stores/minio.py b/src/ell/stores/minio.py new file mode 100644 index 000000000..525a055cc --- /dev/null +++ b/src/ell/stores/minio.py @@ -0,0 +1,55 @@ +import io + +from pydantic import BaseModel, Field +import ell.stores.store + +import minio + + +class MinioConfig(BaseModel): + endpoint: str = Field(description="The endpoint of the minio server") + access_key: str = Field(description="The access key of the minio server") + secret_key: str = Field(description="The secret key of the minio server") + bucket: str = Field(description="The bucket to store the blobs in") + + +class MinioBlobStore(ell.stores.store.BlobStore): + def __init__(self, config: MinioConfig): + self.config = config + self.client = minio.Minio( + #todo. support tls with dev vs prod + secure=False,#False if config.endpoint.startswith("localhost") else True, + endpoint=config.endpoint, + access_key=config.access_key, + secret_key=config.secret_key) + + def store_blob(self, blob: bytes, blob_id: str, **kwargs): + self.client.put_object( + bucket_name=self.config.bucket, + object_name=blob_id, + data=io.BytesIO(blob), + length=len(blob) + ) + return blob_id + + def retrieve_blob(self, blob_id: str) -> bytes: + return self.client.get_object(self.config.bucket, blob_id).read() + +# todo. make this actually async +class AsyncMinioBlobStore(ell.stores.store.AsyncBlobStore): + def __init__(self, config: MinioConfig): + self.config = config + self.client = minio.Minio( + config.endpoint, config.access_key, config.secret_key) + + async def store_blob(self, blob: bytes, blob_id: str, **kwargs): + self.client.put_object( + bucket_name=self.config.bucket, + object_name=blob_id, + data=io.BytesIO(blob), + length=len(blob) + ) + return blob_id + + async def retrieve_blob(self, blob_id: str) -> bytes: + return self.client.get_object(self.config.bucket, blob_id).read() diff --git a/src/ell/stores/models/core.py b/src/ell/stores/models/core.py index f339bbab6..7859b9ff6 100644 --- a/src/ell/stores/models/core.py +++ b/src/ell/stores/models/core.py @@ -8,9 +8,6 @@ from ell.types.message import Any, Any, Field, Message, Optional from sqlmodel import Column, Field, SQLModel -from typing import Optional - -from typing import Optional from typing import Dict, List, Union, Any, Optional from pydantic import BaseModel @@ -20,8 +17,16 @@ from sqlmodel import Field, SQLModel, Relationship, JSON, Column from sqlalchemy import Index, func - from typing import Any +import ell.types.serialize + +def utc_now() -> datetime: + """ + Returns the current UTC timestamp. + Serializes to ISO-8601. 
+ """ + return datetime.now(tz=timezone.utc) + class SerializedLMPUses(SQLModel, table=True): """ @@ -95,6 +100,21 @@ class SerializedLMP(SerializedLMPBase, table=True): evaluation_runs: List["SerializedEvaluationRun"] = Relationship(back_populates="evaluated_lmp") + @classmethod + def coerce(cls, input: ell.types.serialize.WriteLMPInput): + return cls( + lmp_id=input.lmp_id, + lmp_type=input.lmp_type, + name=input.name, + source=input.source, + dependencies=input.dependencies, + api_params=input.api_params, + version_number=input.version_number, + initial_global_vars=input.initial_global_vars, + initial_free_vars=input.initial_free_vars, + commit_message=input.commit_message, + created_at=input.created_at + ) class Config: table_name = "serializedlmp" @@ -173,6 +193,10 @@ def should_externalize(self) -> bool: class InvocationContents(InvocationContentsBase, table=True): invocation: "Invocation" = Relationship(back_populates="contents") + @classmethod + def coerce(cls, input: ell.types.serialize.InvocationContents): + return cls(**input.model_dump()) + class Invocation(InvocationBase, table=True): lmp: SerializedLMP = Relationship(back_populates="invocations") @@ -208,3 +232,15 @@ class Invocation(InvocationBase, table=True): ), ) evaluation_result_datapoints: List["EvaluationResultDatapoint"] = Relationship(back_populates="invocation_being_labeled") + + @classmethod + def coerce(cls, input: ell.types.serialize.Invocation): + fields = { + field: getattr(input, field) + for field in input.model_fields + if field != "contents" + } + return cls( + **fields, + contents=InvocationContents.coerce(input.contents) + ) \ No newline at end of file diff --git a/src/ell/stores/sql.py b/src/ell/stores/sql.py index 7766c21b3..1c6e052dd 100644 --- a/src/ell/stores/sql.py +++ b/src/ell/stores/sql.py @@ -1,3 +1,4 @@ +import logging from datetime import datetime, timedelta import os from typing import Any, Optional, Dict, List, Set, Union @@ -6,6 +7,7 @@ from pathlib import Path from typing import Any, Optional, Dict, List, Set from sqlmodel import Session, SQLModel, create_engine, select +import ell.stores.store from ell.stores.migrations import init_or_migrate_database import ell.stores.store from sqlalchemy.sql import text @@ -21,30 +23,35 @@ SerializedEvaluationRun, ) from ell.stores.models.core import InvocationTrace, SerializedLMP, Invocation, InvocationContents -from sqlalchemy import func, and_ +from sqlalchemy import func, and_, Engine from ell.util.serialization import pydantic_ltype_aware_cattr, utc_now import gzip import json from sqlalchemy.exc import IntegrityError -import logging - logger = logging.getLogger(__name__) class SQLStore(ell.stores.store.Store): - def __init__(self, db_uri: str, blob_store: Optional[ell.stores.store.BlobStore] = None): - # XXX: Use Serialization serialzie_object in incoming PR. - self.engine = create_engine( - db_uri, - json_serializer=lambda obj: json.dumps( - pydantic_ltype_aware_cattr.unstructure(obj), - sort_keys=True, - default=repr, - ensure_ascii=False, - ), - ) - + def __init__(self, db_uri: str = None, blob_store: Optional[ell.stores.store.BlobStore] = None, engine: Optional[Engine] = None): + if engine is not None: + self.engine = engine + elif db_uri is None: + raise ValueError( + # todo. better message + "db_uri is required ") + else: + # XXX: Use Serialization serialzie_object in incoming PR. 
+ self.engine = create_engine( + db_uri, + json_serializer=lambda obj: json.dumps( + pydantic_ltype_aware_cattr.unstructure(obj), + sort_keys=True, + default=repr, + ensure_ascii=False, + ), + ) + init_or_migrate_database(self.engine) self.open_files: Dict[str, Dict[str, Any]] = {} super().__init__(blob_store) @@ -168,7 +175,7 @@ def write_evaluation_run(self, evaluation_run: SerializedEvaluationRun) -> int: return evaluation_run.id def write_evaluation_run_intermediate(self, row_result : EvaluationResultDatapoint) -> None: - # add a new result datapoint + # add a new result datapoint with Session(self.engine) as session: session.add(row_result) session.commit() @@ -229,6 +236,13 @@ def get_latest_lmps( session, skip=skip, limit=limit, subquery=subquery, **filters ) + def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]: + if session is None: + with Session(self.engine) as session: + return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + else: + return session.exec(select(SerializedLMP).where(SerializedLMP.lmp_id == lmp_id)).first() + def get_lmps( self, session: Session, @@ -444,12 +458,12 @@ def get_eval_versions_by_name(self, name: str) -> List[SerializedEvaluation]: def get_evaluation_run(self, session: Session, run_id: str) -> SerializedEvaluationRun: query = select(SerializedEvaluationRun).where( SerializedEvaluationRun.id == run_id, - + ) result = session.exec(query).one() return result - + def get_evaluation_run_results(self, session: Session, run_id: str, skip: int = 0, limit: int = 100, filters : Optional[Dict[str, Any]] = None) -> List[EvaluationResultDatapoint]: query = select(EvaluationResultDatapoint).where( EvaluationResultDatapoint.evaluation_run_id == run_id @@ -460,23 +474,38 @@ def get_evaluation_run_results(self, session: Session, run_id: str, skip: int = query = query.where(getattr(EvaluationResultDatapoint, key) == value) query = query.offset(skip).limit(limit) - + results = session.exec(query).all() print(f"Found {len(results)} results for run {run_id}") return list(results) class SQLiteStore(SQLStore): - def __init__(self, db_dir: str): + def __init__(self, db_dir: str, blob_store: Optional[ell.stores.store.BlobStore] = None): assert not db_dir.endswith(".db"), "Create store with a directory not a db." + if ":memory:" in db_dir: + from sqlalchemy.pool import StaticPool + # todo. 
set up blob store for in-memory + engine = create_engine( + "sqlite://", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + json_serializer=lambda obj: + json.dumps(pydantic_ltype_aware_cattr.unstructure(obj), + sort_keys=True, + default=repr, + ensure_ascii=False + ) + ) + super().__init__(engine=engine) + return os.makedirs(db_dir, exist_ok=True) self.db_dir = db_dir db_path = os.path.join(db_dir, "ell.db") - blob_store = SQLBlobStore(db_dir) + blob_store = SQLBlobStore(db_dir) if blob_store is None else blob_store super().__init__(f"sqlite:///{db_path}", blob_store=blob_store) - class SQLBlobStore(ell.stores.store.BlobStore): def __init__(self, db_dir: str): self.db_dir = db_dir @@ -507,3 +536,5 @@ def _get_blob_path(self, id: str, depth: int = 2) -> str: class PostgresStore(SQLStore): def __init__(self, db_uri: str): super().__init__(db_uri) + logger.debug("Postgres store initialized") + diff --git a/src/ell/stores/store.py b/src/ell/stores/store.py index 737743976..a85328dc2 100644 --- a/src/ell/stores/store.py +++ b/src/ell/stores/store.py @@ -1,16 +1,23 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime -from typing import Any, Optional, Dict, List, Set, Union +from typing import Any, Optional, Dict, List, Set, Union, TYPE_CHECKING + from ell.types._lstr import _lstr from ell.stores.models.core import SerializedLMP, Invocation from ell.types.message import InvocableLM from ell.stores.models.evaluations import EvaluationResultDatapoint, EvaluationRunLabelerSummary, SerializedEvaluation, SerializedEvaluationRun # from ell.types.studio import SerializedEvaluation, SerializedEvaluationRun +if TYPE_CHECKING: + from sqlmodel import Session +else: + Session = None + + class BlobStore(ABC): @abstractmethod - def store_blob(self, blob: bytes, blob_id : str) -> str: + def store_blob(self, blob: bytes, blob_id : str, metadata: Optional[Dict[str, Any]] = None) -> str: """Store a blob and return its identifier.""" pass @@ -19,6 +26,18 @@ def retrieve_blob(self, blob_id: str) -> bytes: """Retrieve a blob by its identifier.""" pass + +class AsyncBlobStore(BlobStore): + @abstractmethod + async def store_blob(self, blob: bytes, blob_id: str, metadata: Optional[Dict[str, Any]] = None) -> str: + """Store a blob and return its identifier.""" + pass + + @abstractmethod + async def retrieve_blob(self, blob_id: str) -> bytes: + """Retrieve a blob by its identifier.""" + pass + class Store(ABC): """ Abstract base class for serializers. Defines the interface for serializing and deserializing LMPs and invocations. @@ -124,6 +143,12 @@ def get_eval_versions_by_name(self, name: str) -> List[SerializedEvaluation]: pass + @abstractmethod + def get_lmp(self, lmp_id: str, session: Optional[Session] = None) -> Optional[SerializedLMP]: + """ + Get an LMP by its id. 
+ """ + pass @contextmanager def freeze(self, *lmps: InvocableLM): diff --git a/src/ell/studio/__main__.py b/src/ell/studio/__main__.py index 484aa18bf..3206145c8 100644 --- a/src/ell/studio/__main__.py +++ b/src/ell/studio/__main__.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import socket import time import webbrowser @@ -32,12 +33,33 @@ def _setup_logging(level): def main(): parser = ArgumentParser(description="ell studio") - parser.add_argument("--storage-dir" , default=None, - help="Directory for filesystem serializer storage (default: current directory)") - parser.add_argument("--pg-connection-string", default=None, - help="PostgreSQL connection string (default: None)") - parser.add_argument("--host", default="127.0.0.1", help="Host to run the server on (default: localhost)") - parser.add_argument("--port", type=int, default=5555, help="Port to run the server on (default: 5555)") + parser.add_argument("--storage-dir" , + default=os.getenv("ELL_STORAGE_DIR"), + help="Directory for filesystem serialize storage (default: None, env: ELL_STORAGE_DIR)") + parser.add_argument("--pg-connection-string", + default=os.getenv("ELL_PG_CONNECTION_STRING"), + help="PostgreSQL connection string (default: None, env: ELL_PG_CONNECTION_STRING)") + parser.add_argument("--mqtt-connection-string", + default=os.getenv("ELL_MQTT_CONNECTION_STRING"), + help="MQTT connection string (default: None, env: ELL_MQTT_CONNECTION_STRING)") + parser.add_argument("--minio-endpoint", + default=os.getenv("ELL_MINIO_ENDPOINT"), + help="MinIO endpoint (default: None, env: ELL_MINIO_ENDPOINT)") + parser.add_argument("--minio-access-key", + default=os.getenv("ELL_MINIO_ACCESS_KEY"), + help="MinIO access key (default: None, env: ELL_MINIO_ACCESS_KEY)") + parser.add_argument("--minio-secret-key", + default=os.getenv("ELL_MINIO_SECRET_KEY"), + help="MinIO secret key (default: None, env: ELL_MINIO_SECRET_KEY)") + parser.add_argument("--minio-bucket", default=os.getenv("ELL_MINIO_BUCKET"), + help="MinIO bucket (default: None, env: ELL_MINIO_BUCKET)") + parser.add_argument("--host", + default=os.getenv("ELL_STUDIO_HOST") or "0.0.0.0", + help="Host to run the server on (default: 0.0.0.0, env: ELL_STUDIO_HOST)") + parser.add_argument("--port", + type=int, + default=int(os.getenv("ELL_STUDIO_PORT") or 5555), + help="Port to run the server on (default: 5555, env: ELL_STUDIO_PORT)") parser.add_argument("--dev", action="store_true", help="Run in development mode") parser.add_argument("--dev-static-dir", default=None, help="Directory to serve static files from in development mode") parser.add_argument("--open", action="store_true", help="Opens the studio web UI in a browser") @@ -47,10 +69,17 @@ def main(): _setup_logging(logging.DEBUG if args.verbose else logging.INFO) if args.dev: - assert args.port == 5555, "Port must be 5000 in development mode" - - config = Config.create(storage_dir=args.storage_dir, - pg_connection_string=args.pg_connection_string) + assert args.port == 5555, "Port must be 5555 in development mode" + + config = Config.create( + storage_dir=args.storage_dir, + pg_connection_string=args.pg_connection_string, + mqtt_connection_string=args.mqtt_connection_string, + minio_endpoint=args.minio_endpoint, + minio_access_key=args.minio_access_key, + minio_secret_key=args.minio_secret_key, + minio_bucket=args.minio_bucket + ) app = create_app(config) if not args.dev: diff --git a/src/ell/studio/config.py b/src/ell/studio/config.py index e56a834d7..b54fce265 100644 --- a/src/ell/studio/config.py +++ 
b/src/ell/studio/config.py @@ -18,16 +18,23 @@ def ell_home() -> str: class Config(BaseModel): pg_connection_string: Optional[str] = None storage_dir: Optional[str] = None + mqtt_connection_string: Optional[str] = None + minio_endpoint: Optional[str] = None + minio_access_key: Optional[str] = None + minio_secret_key: Optional[str] = None + minio_bucket: Optional[str] = None @classmethod def create( cls, storage_dir: Optional[str] = None, pg_connection_string: Optional[str] = None, + mqtt_connection_string: Optional[str] = None, + minio_endpoint: Optional[str] = None, + minio_access_key: Optional[str] = None, + minio_secret_key: Optional[str] = None, + minio_bucket: Optional[str] = None, ) -> 'Config': - pg_connection_string = pg_connection_string or os.getenv("ELL_PG_CONNECTION_STRING") - storage_dir = storage_dir or os.getenv("ELL_STORAGE_DIR") - # Enforce that we use either sqlite or postgres, but not both if pg_connection_string is not None and storage_dir is not None: raise ValueError("Cannot use both sqlite and postgres") @@ -37,4 +44,12 @@ def create( # This intends to honor the default we had set in the CLI storage_dir = os.getcwd() - return cls(pg_connection_string=pg_connection_string, storage_dir=storage_dir) + return cls( + pg_connection_string=pg_connection_string, + storage_dir=storage_dir, + mqtt_connection_string=mqtt_connection_string, + minio_endpoint=minio_endpoint, + minio_access_key=minio_access_key, + minio_secret_key=minio_secret_key, + minio_bucket=minio_bucket + ) diff --git a/src/ell/studio/server.py b/src/ell/studio/server.py index 40ea6a809..a28d7a9e9 100644 --- a/src/ell/studio/server.py +++ b/src/ell/studio/server.py @@ -1,6 +1,11 @@ +import asyncio +from contextlib import asynccontextmanager from typing import Optional, Dict, Any, List from sqlmodel import Session + +from ell.serialize.serializer import get_serializer, get_blob_store +from ell.serialize.config import SerializeConfig from ell.stores.sql import PostgresStore, SQLiteStore from ell import __version__ from fastapi import FastAPI, Query, HTTPException, Depends, Response, WebSocket, WebSocketDisconnect @@ -8,15 +13,18 @@ import logging import json from ell.studio.config import Config -from ell.studio.connection_manager import ConnectionManager from ell.studio.datamodels import EvaluationResultDatapointPublic, InvocationPublicWithConsumes, SerializedLMPWithUses, EvaluationPublic, SpecificEvaluationRunPublic from ell.stores.models.core import SerializedLMP from datetime import datetime, timedelta from sqlmodel import select +from contextlib import AsyncExitStack from ell.stores.models.evaluations import SerializedEvaluation +from ell.api.pubsub.abc import PubSub +from ell.api.pubsub.websocket import WebSocketPubSub + logger = logging.getLogger(__name__) @@ -24,14 +32,42 @@ def get_serializer(config: Config): + serialize_config = SerializeConfig(**config.model_dump()) + blob_store = get_blob_store(serialize_config) if config.pg_connection_string: - return PostgresStore(config.pg_connection_string) + return PostgresStore(config.pg_connection_string, blob_store) elif config.storage_dir: - return SQLiteStore(config.storage_dir) + return SQLiteStore(config.storage_dir, blob_store) else: raise ValueError("No storage configuration found") +pubsub: Optional[PubSub] = None + +async def get_pubsub(): + yield pubsub + + +async def setup_pubsub(config: Config, exit_stack: AsyncExitStack): + """Set up the appropriate pubsub client based on configuration.""" + if config.storage_dir is not None: + return 
WebSocketPubSub(), None + + if config.mqtt_connection_string is not None: + try: + from ell.api.pubsub.mqtt import setup + except ImportError as e: + raise ImportError( + "Received mqtt_connection_string but dependencies missing. Install with `pip install -U ell-ai[mqtt]. More info: https://docs.ell.so/installation") from e + + pubsub, mqtt_client = await setup(config.mqtt_connection_string) + exit_stack.push_async_exit(mqtt_client) + logger.info("Connected to MQTT") + + loop = asyncio.get_event_loop() + return pubsub, pubsub.listen(loop) + + return None, None def create_app(config:Config): serializer = get_serializer(config) @@ -40,7 +76,29 @@ def get_session(): with Session(serializer.engine) as session: yield session - app = FastAPI(title="ell Studio", version=__version__) + @asynccontextmanager + async def lifespan(app: FastAPI): + global pubsub + exit_stack = AsyncExitStack() + pubsub_task = None + + try: + pubsub, pubsub_task = await setup_pubsub(config, exit_stack) + yield + + finally: + if pubsub_task and not pubsub_task.done(): + pubsub_task.cancel() + try: + await pubsub_task + except asyncio.CancelledError: + pass + + await exit_stack.aclose() + pubsub = None + + + app = FastAPI(title="ell Studio", version=__version__, lifespan=lifespan) # Enable CORS for all origins app.add_middleware( @@ -51,17 +109,21 @@ def get_session(): allow_headers=["*"], ) - manager = ConnectionManager() @app.websocket("/ws") - async def websocket_endpoint(websocket: WebSocket): - await manager.connect(websocket) + async def websocket_endpoint(websocket: WebSocket,pubsub: PubSub = Depends(get_pubsub)): + await websocket.accept() + # NB. for now, studio does not dynamically subscribe to data topics. We subscribe every client to these by + # default. If desired, apps may issue a "subscribe" message that we can handle in websocket.receive_text below + # to sign up to receive data from arbitrary topics. They can unsubscribe when done via an "unsubscribe" message. 
+ await pubsub.subscribe_async("all", websocket) + await pubsub.subscribe_async("lmp/#", websocket) try: while True: data = await websocket.receive_text() # Handle incoming WebSocket messages if needed except WebSocketDisconnect: - manager.disconnect(websocket) + pubsub.unsubscribe_from_all(websocket) @app.get("/api/latest/lmps", response_model=list[SerializedLMPWithUses]) @@ -195,9 +257,13 @@ def get_lmp_history( return history + # Used by studio to publish changes from a SQLLite store directly async def notify_clients(entity: str, id: Optional[str] = None): + if pubsub is None: + logger.error("No pubsub client, cannot notify clients") + return message = json.dumps({"entity": entity, "id": id}) - await manager.broadcast(message) + await pubsub.publish("all", message) # Add this method to the app object app.notify_clients = notify_clients @@ -244,7 +310,7 @@ def get_evaluations( return evaluations - + @app.get("/api/latest/evaluations", response_model=List[EvaluationPublic]) def get_latest_evaluations( skip: int = Query(0, ge=0), @@ -268,8 +334,8 @@ def get_evaluation( if not evaluation: raise HTTPException(status_code=404, detail="Evaluation not found") return evaluation[0] - - + + @app.get("/api/evaluation-runs/{run_id}", response_model=SpecificEvaluationRunPublic) def get_evaluation_run( @@ -278,7 +344,7 @@ def get_evaluation_run( ): runs = serializer.get_evaluation_run(session, run_id) return runs - + @app.get("/api/evaluation-runs/{run_id}/results", response_model=List[EvaluationResultDatapointPublic]) def get_evaluation_run_results( run_id: str, @@ -293,7 +359,7 @@ def get_evaluation_run_results( limit=limit, ) return results - + @app.get("/api/all-evaluations", response_model=List[EvaluationPublic]) def get_all_evaluations( skip: int = Query(0, ge=0), @@ -309,7 +375,7 @@ def get_all_evaluations( ) results = session.exec(query).all() return list(results) - + @app.get("/api/dataset/{dataset_id}") def get_dataset( dataset_id: str, @@ -317,27 +383,27 @@ def get_dataset( ): if not serializer.blob_store: raise HTTPException(status_code=400, detail="Blob storage not configured") - + try: # Get the blob data blob_data = serializer.blob_store.retrieve_blob(dataset_id) - + # Check if size is under 5MB if len(blob_data) > 5 * 1024 * 1024: # 5MB in bytes raise HTTPException( status_code=413, detail="Dataset too large to preview (>5MB)" ) - + # Decode and parse JSON dataset_json = json.loads(blob_data.decode('utf-8')) - + return { "size": len(blob_data), "data": dataset_json } - + except FileNotFoundError: raise HTTPException(status_code=404, detail="Dataset not found") except json.JSONDecodeError: @@ -345,5 +411,5 @@ def get_dataset( except Exception as e: logger.error(f"Error retrieving dataset: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") - + return app diff --git a/src/ell/types/_lstr.py b/src/ell/types/_lstr.py index 55f5327a4..47c6fab4c 100644 --- a/src/ell/types/_lstr.py +++ b/src/ell/types/_lstr.py @@ -99,6 +99,7 @@ def __new__( instance = super(_lstr, cls).__new__(cls, content) # instance._logits = logits if isinstance(origin_trace, str): + # TODO. 
pydantic validation splits on ',', it would be good to have this in one place or standardize on a list for the serialized format unless ',' denotes something else instance.__origin_trace__ = frozenset({origin_trace}) else: instance.__origin_trace__ = ( @@ -116,8 +117,8 @@ def __get_pydantic_core_schema__( def validate_lstr(value): if isinstance(value, dict) and value.get("__lstr", False): content = value["content"] - origin_trace = value["__origin_trace__"].split(",") - return cls(content, origin_trace=origin_trace) + origin_trace = value["__origin_trace__"].split(",") if isinstance(value["__origin_trace__"], str) else frozenset(value["__origin_trace__"]) + return cls(content, origin_trace=origin_trace) # type: ignore elif isinstance(value, str): return cls(value) elif isinstance(value, cls): diff --git a/src/ell/types/message.py b/src/ell/types/message.py index 61acb4ad6..e11566e09 100644 --- a/src/ell/types/message.py +++ b/src/ell/types/message.py @@ -1,19 +1,19 @@ # todo: implement tracing for structured outs. this a v2 feature. +import base64 import json -from ell.types._lstr import _lstr +from concurrent.futures import ThreadPoolExecutor, as_completed from functools import cached_property -import numpy as np -import base64 from io import BytesIO -from PIL import Image as PILImage - -from pydantic import BaseModel, ConfigDict, Field, model_validator, field_serializer +from types import FunctionType +from typing import Any, Callable, Dict, List, Optional, Union -from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +from PIL import Image as PILImage +from pydantic import BaseModel, ConfigDict, Field, model_validator, field_serializer, field_validator -from typing import Any, Callable, Dict, List, Optional, Union +from ell.types._lstr import _lstr +from ell.util.serialization import serialize_image, unstructure_lstr -from ell.util.serialization import serialize_image _lstr_generic = Union[_lstr, str] InvocableTool = Callable[..., Union["ToolResult", _lstr_generic, List["ContentBlock"], ]] @@ -22,8 +22,8 @@ class ToolResult(BaseModel): - tool_call_id: _lstr_generic - result: List["ContentBlock"] + tool_call_id: _lstr_generic = Field(description="Id of the tool call from the model that led the tool to be called (`'call_{id}'`)") + result: List["ContentBlock"] = Field(description="Tool call output as a list of ell ContentBlocks") @property def text(self) -> str: @@ -40,35 +40,75 @@ def text_only(self) -> str: def __repr__(self): return f"{self.__class__.__name__}(tool_call_id={self.tool_call_id}, result={_content_to_text(self.result)})" +class ToolReference(BaseModel): + """A reference to an invocable tool""" + fqn: str = Field(description="The fully qualified name of the tool") + hash: str = Field(description="The hash of the tool and its dependencies") + class ToolCall(BaseModel): - tool : InvocableTool - tool_call_id : Optional[_lstr_generic] = Field(default=None) - params : BaseModel + tool: Union[InvocableTool, ToolReference] = Field(description="The tool function to call or a reference to it when serialized") + tool_call_id: Optional[_lstr_generic] = Field(default=None) + params: Union[Dict[str, Any], BaseModel] = Field(description="Arguments for the tool call provided by the model.") + + def __init__(self, tool, params: Optional[Union[BaseModel, Dict[str, Any]]], tool_call_id: Optional[_lstr_generic]=None): + if (not isinstance(params, BaseModel)) and isinstance(tool, FunctionType) and hasattr(tool, '__ell_params_model__'): + params = 
tool.__ell_params_model__(**params) + if isinstance(tool_call_id, dict): + tool_call_id = _lstr(content=tool_call_id['content'], origin_trace=tool_call_id.get('__origin_trace__'), logits=tool_call_id.get('logits')) - def __init__(self, tool, params : Union[BaseModel, Dict[str, Any]], tool_call_id=None): - if not isinstance(params, BaseModel): - params = tool.__ell_params_model__(**params) #convenience. super().__init__(tool=tool, tool_call_id=tool_call_id, params=params) + @field_serializer('tool') + def serialize_tool(self, tool: Union[InvocableTool, ToolReference], _info): + if isinstance(tool, ToolReference): + return tool + return ToolReference( + # todo(alex). add the value of fqn we want to standardize on to all lmps so we don't keep using qualname + fqn=tool.__qualname__, + hash=getattr(tool, '__ell_hash__', 'unknown') + ) + + @field_serializer('params') + def _serialize_params(self, params: Union[Dict[str, Any], BaseModel]) -> Dict[str, Any]: + if isinstance(params, dict): + return params + return params.model_dump(exclude_none=True, exclude_unset=True) + + def serialize_params(self) -> Dict[str, Any]: + return self._serialize_params(self.params) + + @field_serializer('tool_call_id') + def serialize_tool_call_id(self, tool_call_id: _lstr_generic): + if tool_call_id is None: + return None + origin_trace = tool_call_id.__dict__['__origin_trace__'] + if origin_trace: + return unstructure_lstr(tool_call_id) + return tool_call_id + def __call__(self, **kwargs): assert not kwargs, "Unexpected arguments provided. Calling a tool uses the params provided in the ToolCall." + assert not isinstance(self.tool, ToolReference), f"Tools are not invocable once serialized. ToolCall.tool is a ToolReference: {self.tool}" # XXX: TODO: MOVE TRACKING CODE TO _TRACK AND OUT OF HERE AND API. - return self.tool(**self.params.model_dump()) + return self.tool(**self.serialize_params()) # XXX: Deprecate in 0.1.0 def call_and_collect_as_message_block(self): raise DeprecationWarning("call_and_collect_as_message_block is deprecated. Use collect_as_content_block instead.") def call_and_collect_as_content_block(self): - res = self.tool(**self.params.model_dump(), _tool_call_id=self.tool_call_id) + if isinstance(self.tool, ToolReference): + raise ValueError(f"Cannot call a tool that is a ToolReference: {self.tool}") + res = self.tool(**self.serialize_params(), + _tool_call_id=self.tool_call_id) return ContentBlock(tool_result=res) def call_and_collect_as_message(self): return Message(role="user", content=[self.call_and_collect_as_message_block()]) def __repr__(self): - return f"{self.__class__.__name__}({self.tool.__name__}({self.params}), tool_call_id='{self.tool_call_id}')" + return f"{self.__class__.__name__}({self.tool.__name__ if hasattr(self.tool, '__name__') else str(self.tool)}({self.params}), tool_call_id='{self.tool_call_id}')" class ImageContent(BaseModel): @@ -129,13 +169,13 @@ class ContentBlock(BaseModel): image: Optional[ImageContent] = Field(default=None) audio: Optional[Union[np.ndarray, List[float]]] = Field(default=None) tool_call: Optional[ToolCall] = Field(default=None) - parsed: Optional[BaseModel] = Field(default=None) + parsed: Optional[Union[Dict[str, Any], BaseModel]] = Field(default=None) tool_result: Optional[ToolResult] = Field(default=None) # TODO: Add a JSON type? This would be nice for response_format. This is different than resposne_format = model. Or we could be opinionated and automatically parse the json response. That might be nice. 
# This breaks us maintaing parity with the openai python client in some sen but so does image. def __init__(self, *args, **kwargs): - if "image" in kwargs and not isinstance(kwargs["image"], ImageContent): + if "image" in kwargs and kwargs['image'] is not None and not isinstance(kwargs["image"], ImageContent): # todo(alex). are we looking for dict here? im = kwargs["image"] = ImageContent.coerce(kwargs["image"]) # XXX: Backwards compatibility, Deprecate. if (d := kwargs.get("image_detail", None)): im.detail = d @@ -175,7 +215,7 @@ def type(self): @property def content(self): - return getattr(self, self.type) + return getattr(self, self.type) # type: ignore @classmethod def coerce(cls, content: AnyContent) -> "ContentBlock": @@ -255,14 +295,25 @@ def coerce(cls, content: AnyContent) -> "ContentBlock": return cls(image=ImageContent.coerce(content)) if isinstance(content, BaseModel): return cls(parsed=content) + if isinstance(content, dict): + return cls(**content) - raise ValueError(f"Invalid content type: {type(content)}") + raise ValueError(f"Invalid ContentBlock content type: {type(content)}") @field_serializer('parsed') def serialize_parsed(self, value: Optional[BaseModel], _info): if value is None: return None return value.model_dump(exclude_none=True, exclude_unset=True) + + @field_validator('parsed' ,mode='wrap') + def deserialize_parsed(cls, value: Optional[Union[Dict[str, Any],BaseModel]], _info): + # Why must we do this? + # pydantic returns an empty BaseModel() whenever parsed is a dict + if value is None or isinstance(value, (dict,BaseModel)): + return value + raise ValueError(f"Invalid ContentBlock.parsed value: {type(value)}") + def to_content_blocks( @@ -303,7 +354,7 @@ def to_content_blocks( if not isinstance(content, list): content = [content] - + return [ContentBlock.model_validate(ContentBlock.coerce(c)) for c in content] @@ -437,22 +488,7 @@ def serialize_content(self, content: List[ContentBlock]): for block in content ] - @classmethod - def model_validate(cls, obj: Any) -> 'Message': - """Custom validation to handle deserialization""" - if isinstance(obj, dict): - if 'content' in obj and isinstance(obj['content'], list): - content_blocks = [] - for block in obj['content']: - if isinstance(block, dict): - if 'text' in block: - block['text'] = str(block['text']) if block['text'] is not None else None - content_blocks.append(ContentBlock.model_validate(block)) - else: - content_blocks.append(ContentBlock.coerce(block)) - obj['content'] = content_blocks - return super().model_validate(obj) - + #todo(alex): needed? @classmethod def model_validate_json(cls, json_str: str) -> 'Message': """Custom validation to handle deserialization from JSON string""" diff --git a/src/ell/types/serialize.py b/src/ell/types/serialize.py new file mode 100644 index 000000000..913f366a6 --- /dev/null +++ b/src/ell/types/serialize.py @@ -0,0 +1,161 @@ +import uuid +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, AwareDatetime, Field, field_serializer, field_validator + +from ell.types.lmp import LMPType +from ell.types.message import Message + + +def utc_now() -> datetime: + """ + Returns the current UTC timestamp. + Serializes to ISO-8601. + """ + return datetime.now(tz=timezone.utc) + + +class WriteLMPInput(BaseModel): + """ + Arguments to write a LMP. 
+ """ + lmp_id: str + name: str + source: str + dependencies: str + lmp_type: LMPType + api_params: Optional[Dict[str, Any]] = None + initial_free_vars: Optional[Dict[str, Any]] = None + initial_global_vars: Optional[Dict[str, Any]] = None + created_at: AwareDatetime = Field(default_factory=utc_now) + uses: List[str] = Field(default_factory=list) + commit_message: Optional[str] = None + version_number: Optional[int] = None + + +# todo. see if we can get rid of this...the only difference with writelmpinput is some properties are read only +class LMP(BaseModel): + lmp_id: str + name: str + source: str + dependencies: str + lmp_type: LMPType + api_params: Optional[Dict[str, Any]] + initial_free_vars: Optional[Dict[str, Any]] + initial_global_vars: Optional[Dict[str, Any]] + created_at: AwareDatetime + version_number: int + commit_message: Optional[str] + num_invocations: int + + +class GetLMPInput(BaseModel): + id: str + + +GetLMPOutput = Optional[LMP] + + +class InvocationContents(BaseModel): + invocation_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="ID of the invocation the contents belong to") + params: Optional[Dict[str, Any]] = Field(description="The parameters of the LMP at the time of the invocation", default=None) + results: Optional[List[Message]] = Field(description="The output of the invocation as a list of ell Messages", default=None) + invocation_api_params: Optional[Dict[str, Any]] = Field(description="Arguments the model API was called with", default=None) + global_vars: Optional[Dict[str, Any]] = Field(description="Global variable bindings and their values at the time of the invocation", default=None) + free_vars: Optional[Dict[str, Any]] = Field(description="Free variable bindings and their values at the time of the invocation", default=None) + is_external: bool = Field(default=False, description="Whether the invocation contents are stored externally in a blob store. If they are they can be retrieved by 'invocation-{invocation_id}'.") + + @cached_property + def total_size_bytes(self) -> int: + """ + Returns the total uncompressed size of the invocation contents as JSON in bytes. + """ + import json + json_fields = [ + self.params, + self.results, + self.invocation_api_params, + self.global_vars, + self.free_vars + ] + # todo(alex): we may want to bring this in line with other json serialization + return sum( + len(json.dumps(field, default=(lambda x: json.dumps(x.model_dump(), default=str, ensure_ascii=False) + if isinstance(x, BaseModel) else str(x)), ensure_ascii=False).encode('utf-8')) + for field in json_fields if field is not None + ) + + @cached_property + def should_externalize(self) -> bool: + return self.total_size_bytes > 102400 # Precisely 100kb in bytes + + +class Invocation(BaseModel): + """ + An invocation of an LMP. + """ + id: Optional[str] = None + lmp_id: str + latency_ms: float + prompt_tokens: Optional[int] = None + completion_tokens: Optional[int] = None + state_cache_key: Optional[str] = None + created_at: AwareDatetime = Field(default_factory=utc_now) + used_by_id: Optional[str] = None + contents: InvocationContents + + # Note: we must set to always right now, because the global json serializer calls model_dump instead of + # model_dump_json and then json.dumps with default of repr. 
would prefer when_used=json but + # tbh it's probably not needed as i think pydantic already handles this for json + + @field_serializer('created_at', when_used='always') + def serialize_date(self, created_at: AwareDatetime): + return str(created_at) + + @field_validator('created_at', mode="before") + def deserialize_and_validate_date(cls, created_at: Union[str, AwareDatetime]): + if isinstance(created_at, str): + dt = datetime.fromisoformat(created_at) + if dt.tzinfo is None: + raise ValueError( + "Datetime string must include timezone information") + return dt + return created_at + + +class WriteInvocationInput(BaseModel): + """ + Arguments to write an invocation. + """ + invocation: Invocation + consumes: List[str] + + +class LMPInvokedEvent(BaseModel): + lmp_id: str + # invocation_id: str + consumes: List[str] + + +class WriteBlobInput(BaseModel): + """ + Arguments to write a blob to a blob store + """ + blob_id: str + blob: bytes + metadata: Optional[Dict[str, Any]] = None + + +# class Blob(BaseModel): +# blob_id: str +# blob: bytes +# content_type: str +# metadata: Optional[Dict[str, Any]] = None +# +# @cached_property +# def size_bytes(self) -> int: +# return len(self.blob) +# +# diff --git a/src/ell/util/errors.py b/src/ell/util/errors.py new file mode 100644 index 000000000..3805df15d --- /dev/null +++ b/src/ell/util/errors.py @@ -0,0 +1,6 @@ +from typing import List + +def missing_ell_extras(message: str, extras: List[str]): + return ImportError( + f"{message}. Enable them with `pip install -U ell-api[{','.join(extras)}]`. More info: https://docs.ell.so/installation" + ) \ No newline at end of file diff --git a/src/ell/util/serialization.py b/src/ell/util/serialization.py index 48e41bcb4..e73ca19d5 100644 --- a/src/ell/util/serialization.py +++ b/src/ell/util/serialization.py @@ -54,7 +54,19 @@ def serialize_image(img): ) def unstructure_lstr(obj): - return dict(content=str(obj), **obj.__dict__, __lstr=True) + if isinstance(obj, str): + return dict(content=obj, __lstr=True) + origin_trace = obj.__dict__['__origin_trace__'] + if origin_trace and isinstance(origin_trace, frozenset): + return dict(content=str(obj), + **obj.__dict__, + origin_trace=list(sorted(origin_trace)), + __lstr=True) + + return dict(content=str(obj), + **obj.__dict__, + __lstr=True) + pydantic_ltype_aware_cattr.register_unstructure_hook( _lstr, diff --git a/tests/api/test_api.py b/tests/api/test_api.py new file mode 100644 index 000000000..769da2af1 --- /dev/null +++ b/tests/api/test_api.py @@ -0,0 +1,367 @@ +from datetime import timezone +from logging import DEBUG +from uuid import uuid4 +import pytest +from typing import Any, Dict, Tuple + +from fastapi import FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel, Field, ValidationError + +import ell +from ell import Message +from ell.serialize.http import EllHTTPSerializer +from ell.serialize.sqlite import SQLiteSerializer, AsyncSQLiteSerializer +from ell.api.server import create_app, get_pubsub, get_serializer +from ell.api.config import Config +from ell.api.logger import setup_logging +from ell.types import ToolCall +from ell.types.serialize import WriteInvocationInput, utc_now, Invocation, InvocationContents +from ell.stores.models import SerializedLMP +from ell.types.lmp import LMPType +from ell.types.serialize import WriteLMPInput + + +@pytest.fixture +def sqlite_serializer() -> SQLiteSerializer: + return SQLiteSerializer(":memory:") + + +@pytest.fixture +def async_sqlite_serializer() -> AsyncSQLiteSerializer: + return
AsyncSQLiteSerializer(":memory:") + + +def test_construct_serialized_lmp(): + serialized_lmp = SerializedLMP( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + version_number=1, + # uses={"used_lmp_1": {}, "used_lmp_2": {}}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + created_at=utc_now() + ) + assert serialized_lmp.lmp_id == "test_lmp_id" + assert serialized_lmp.name == "Test LMP" + assert serialized_lmp.source == "def test_function(): pass" + assert serialized_lmp.dependencies == str(["dep1", "dep2"]) + assert serialized_lmp.api_params == {"param1": "value1"} + assert serialized_lmp.version_number == 1 + assert serialized_lmp.created_at is not None + + +def test_write_lmp_input(): + # Should be able to construct a WriteLMPInput from data + input = WriteLMPInput( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + version_number=1, + ) + + # Should default a created_at to utc_now + assert input.created_at is not None + assert input.created_at.tzinfo == timezone.utc + + # Should be able to construct a SerializedLMP from a WriteLMPInput + model = SerializedLMP(**input.model_dump()) + assert model.created_at == input.created_at + + input2 = WriteLMPInput( + lmp_id="test_lmp_id", + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + api_params={"param1": "value1"}, + initial_global_vars={"global_var1": "value1"}, + initial_free_vars={"free_var1": "value2"}, + commit_message="Initial commit", + version_number=1, + # should work with an isoformat string + created_at=utc_now().isoformat() # type: ignore + ) + model2 = SerializedLMP(**input2.model_dump()) + assert model2.created_at == input2.created_at + assert input2.created_at is not None + assert input2.created_at.tzinfo == timezone.utc + + +def create_test_app(serializer: AsyncSQLiteSerializer) -> Tuple[FastAPI, EllHTTPSerializer, None, Config]: + setup_logging(DEBUG) + config = Config(storage_dir=":memory:") + app = create_app(config) + + publisher = None + + async def get_publisher_override(): + yield publisher + + def get_serializer_override(): + return serializer + + app.dependency_overrides[get_pubsub] = get_publisher_override + app.dependency_overrides[get_serializer] = get_serializer_override + + client = EllHTTPSerializer(client=TestClient(app)) + + return app, client, publisher, config + + +def test_write_lmp(async_sqlite_serializer: AsyncSQLiteSerializer): + _app, client, *_ = create_test_app(async_sqlite_serializer) + + lmp_data: Dict[str, Any] = { + "lmp_id": uuid4().hex, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "lmp_type": LMPType.LM, + "api_params": {"param1": "value1"}, + "version_number": 1, + # "uses": {"used_lmp_1": {}, "used_lmp_2": {}}, + "initial_global_vars": {"global_var1": "value1"}, + "initial_free_vars": {"free_var1": "value2"}, + "commit_message": "Initial commit", + "created_at": utc_now().isoformat().replace("+00:00", "Z"), + "uses": ['used_lmp_1'] + } + + response = client.client.post("/lmp", json=lmp_data) + + # response 
= client.write_lmp( + # WriteLMPInput(**lmp_data), + # ) + + assert response.status_code == 200 + + lmp = client.client.get(f"/lmp/{lmp_data['lmp_id']}") + assert lmp.status_code == 200 + del lmp_data["uses"] # todo. return uses y/n? + assert lmp.json() == {**lmp_data, "num_invocations": 0} + + +def test_write_invocation(async_sqlite_serializer: AsyncSQLiteSerializer): + _app, client, *_ = create_test_app(async_sqlite_serializer) + # Test basic http client functionality + client = client.client + + # first write an lmp.. + lmp_id = uuid4().hex + lmp_data: Dict[str, Any] = { + "lmp_id": lmp_id, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "lmp_type": LMPType.LM, + "api_params": {"param1": "value1"}, + } + + response = client.post("/lmp", json=lmp_data) + + try: + assert response.status_code == 200 + except Exception as e: + print(response.json()) + raise e + + invocation_data = { + "id": uuid4().hex, + "lmp_id": lmp_id, + "args": ["arg1", "arg2"], + "kwargs": {"kwarg1": "value1"}, + "global_vars": {"global_var1": "value1"}, + "free_vars": {"free_var1": "value2"}, + "latency_ms": 100.0, + "invocation_kwargs": {"model": "gpt-4o", "messages": [{"role": "system", + "content": "You are a JSON parser. You respond only in JSON. Do not format using markdown."}, + {"role": "user", + "content": "You are given the following task: \"What is two plus two?\"\n Parse the task into the following type:\n {'$defs': {'Add': {'properties': {'op': {'const': '+', 'enum': ['+'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Add', 'type': 'object'}, 'Div': {'properties': {'op': {'const': '/', 'enum': ['/'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Div', 'type': 'object'}, 'Mul': {'properties': {'op': {'const': '*', 'enum': ['*'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Mul', 'type': 'object'}, 'Sub': {'properties': {'op': {'const': '-', 'enum': ['-'], 'title': 'Op', 'type': 'string'}, 'a': {'title': 'A', 'type': 'number'}, 'b': {'title': 'B', 'type': 'number'}}, 'required': ['op', 'a', 'b'], 'title': 'Sub', 'type': 'object'}}, 'anyOf': [{'$ref': '#/$defs/Add'}, {'$ref': '#/$defs/Sub'}, {'$ref': '#/$defs/Mul'}, {'$ref': '#/$defs/Div'}]}\n "}], + "lm_kwargs": {"temperature": 0.1}, "client": None}, + "contents": {} + } + consumes_data = [] + + input = { + "invocation": invocation_data, + "consumes": consumes_data + } + response = client.post( + "/invocation", + json=input + ) + + print(response.json()) + assert response.status_code == 200 + # assert response.json() == input + + +class MySampleToolInput(BaseModel): + sample_property: str = Field("A thing") + + +@ell.tool() +def my_sample_tool(args: MySampleToolInput = Field( + description="The full name of a city and country, e.g. 
San Francisco, CA, USA")): + return '42' + + +def test_invocation_json_round_trip(): + # pretend it's being tracked + my_sample_tool.__ell_hash__ = "lmp-123" + invocation_id = "invocation-" + uuid4().hex + tool_call = ToolCall( + tool=my_sample_tool, + tool_call_id=uuid4().hex, + params=MySampleToolInput(sample_property="test"), + ) + invocation_contents = InvocationContents( + invocation_id=invocation_id, + results=[Message(role='user', content=[tool_call])] + ) + invocation = Invocation( + id=invocation_id, + lmp_id=uuid4().hex, + latency_ms=42.0, + contents=invocation_contents, + created_at=utc_now() + ) + + # Serialize + result = invocation.model_dump() + + # Deserialize + _invocation=None + try: + _invocation = Invocation.model_validate(result) + except ValidationError as e: + import json + print("\nJSON errors:") + print(json.dumps(e.errors(), default=str,indent=2)) + + # Should be equal + # Except that: + # ToolCall before / after serialization: + # 1. `tool` is a function vs a string + # 2. `params` is a BaseModel (in userland) vs a dictionary + # These are not equivalent + + # What should be equivalent: deserialized forms of serialized forms + assert _invocation.model_dump() == result + +def test_write_invocation_tool_call(async_sqlite_serializer: AsyncSQLiteSerializer): + _app, client, *_ = create_test_app(async_sqlite_serializer) + # Test basic http functionality + client = client.client + + # first write an lmp.. + lmp_id = uuid4().hex + lmp_data: Dict[str, Any] = { + "lmp_id": lmp_id, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + "lmp_type": LMPType.LM, + "api_params": {"param1": "value1"}, + } + response = client.post( + "/lmp", + json=lmp_data + ) + try: + assert response.status_code == 200 + except Exception as e: + print(response.json()) + raise e + + # pretend it's being tracked + my_sample_tool.__ell_hash__ = "lmp-123" + invocation_id = "invocation-" + uuid4().hex + tool_call = ToolCall( + tool=my_sample_tool, + tool_call_id=uuid4().hex, + params=MySampleToolInput(sample_property="test"), + ) + invocation_contents = InvocationContents( + invocation_id=invocation_id, + results=[Message(role='user', content=[tool_call])] + ) + invocation = Invocation( + id=invocation_id, + lmp_id=lmp_id, + latency_ms=42.0, + contents=invocation_contents, + created_at=utc_now() + ) + + response = client.post( + "/invocation", + json={'invocation':invocation.model_dump(),'consumes':[]} + ) + print(response.json()) + assert response.status_code == 200 + +def test_http_client_write_lmp(async_sqlite_serializer: AsyncSQLiteSerializer): + _app, client, *_ = create_test_app(async_sqlite_serializer) + + lmp_data: Dict[str, Any] = { + "lmp_id": uuid4().hex, + "lmp_type": LMPType.LM, + "name": "Test LMP", + "source": "def test_function(): pass", + "dependencies": str(["dep1", "dep2"]), + } + result = client.write_lmp(WriteLMPInput( + lmp_id=lmp_data["lmp_id"], + lmp_type=lmp_data["lmp_type"], + name=lmp_data["name"], + source=lmp_data["source"], + dependencies=lmp_data["dependencies"], + )) + assert result is None + +def test_http_client_write_invocation(async_sqlite_serializer: AsyncSQLiteSerializer): + _app, client, *_ = create_test_app(async_sqlite_serializer) + + # Invocation depends on an lmp being written so write one first + lmp_id = uuid4().hex + + client.write_lmp(WriteLMPInput( + lmp_id=lmp_id, + name="Test LMP", + source="def test_function(): pass", + dependencies=str(["dep1", "dep2"]), + lmp_type=LMPType.LM, + )) + + 
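test_invocation_json_round_trip above leans on a property worth spelling out: the original object and its deserialized form are not equal (a tool function becomes a reference, a params model becomes a dict), but serializing both must produce the same payload. A stripped-down sketch of that idempotence check, using illustrative models rather than ell's types:

from typing import Any
from pydantic import BaseModel, field_serializer

class Params(BaseModel):
    x: int

class Record(BaseModel):
    params: Any

    @field_serializer("params")
    def _dump_params(self, value, _info):
        # Dump nested models to plain dicts so the payload stays JSON-friendly.
        return value.model_dump() if isinstance(value, BaseModel) else value

original = Record(params=Params(x=1))
payload = original.model_dump()            # {'params': {'x': 1}}
restored = Record.model_validate(payload)  # params comes back as a plain dict

assert restored.params != original.params  # the objects differ after the trip
assert restored.model_dump() == payload    # but their serialized forms agree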
invocation_id = uuid4().hex + result = client.write_invocation(WriteInvocationInput( + invocation=Invocation( + id=invocation_id, + lmp_id=lmp_id, + contents=InvocationContents( + invocation_id=invocation_id, + results=[Message(role='user', content="hello")] + ), + created_at=utc_now(), + latency_ms=42.0, + ), + consumes=[] + )) + assert result is None + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/test_message_type.py b/tests/test_message_type.py index 8fd476ad4..4e31d122f 100644 --- a/tests/test_message_type.py +++ b/tests/test_message_type.py @@ -1,3 +1,5 @@ +from uuid import uuid4 + import pytest from pydantic import BaseModel import ell @@ -143,4 +145,39 @@ def test_message_json_serialization(): assert loaded_message.role == original_message.role assert len(loaded_message.content) == len(original_message.content) - assert str(loaded_message.content[0].text) == str(original_message.content[0].text) \ No newline at end of file + assert str(loaded_message.content[0].text) == str(original_message.content[0].text) + +def test_tool_call_json_serialization(): + class MySampleToolInput(BaseModel): + sample_property: str + + @ell.tool() + def my_sample_tool(args: MySampleToolInput): + return '42' + + original_message = Message(role='assistant', content=[ + ToolCall( + tool=my_sample_tool, + tool_call_id=f'call_{uuid4().hex}', + params={'args': MySampleToolInput(sample_property="test")}, + )]) + + message_json = original_message.model_dump_json() + loaded_message = Message.model_validate_json(message_json) + assert loaded_message.tool_calls[0].params == {'args': {'sample_property': 'test'}} + + assert loaded_message.role == original_message.role + assert len(loaded_message.content) == len(original_message.content) + assert str(loaded_message.content[0].text) == str(original_message.content[0].text) + +def test_parsed_json_serialization(): + class DummyFormattedResponse(BaseModel): + field1: str + field2: int + + original_message = Message(role='assistant', content=[ContentBlock(parsed=DummyFormattedResponse(field1="test", field2=42))]) + message_json = original_message.model_dump_json() + loaded_message = Message.model_validate_json(message_json) + assert loaded_message.content[0].parsed == {'field1': 'test', 'field2': 42} + + \ No newline at end of file
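test_parsed_json_serialization above ends with parsed equal to a plain dict, which is exactly what the ContentBlock changes earlier in this diff arrange: the wrap validator keeps dict payloads as dicts instead of letting the Union coerce them into an empty BaseModel. A self-contained sketch of that behaviour, with an illustrative Block model standing in for ContentBlock:

from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, field_serializer, field_validator

class Block(BaseModel):
    parsed: Optional[Union[Dict[str, Any], BaseModel]] = None

    @field_serializer("parsed")
    def _dump_parsed(self, value, _info):
        return value.model_dump() if isinstance(value, BaseModel) else value

    @field_validator("parsed", mode="wrap")
    @classmethod
    def _keep_dicts(cls, value, handler):
        # Bypass the default Union coercion so a dict payload survives verbatim
        # rather than being turned into an empty BaseModel.
        if value is None or isinstance(value, (dict, BaseModel)):
            return value
        raise ValueError(f"unsupported parsed value: {type(value)}")

class DummyFormattedResponse(BaseModel):
    field1: str
    field2: int

block = Block(parsed=DummyFormattedResponse(field1="test", field2=42))
restored = Block.model_validate_json(block.model_dump_json())
assert restored.parsed == {"field1": "test", "field2": 42}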
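Similarly, the created_at handling on Invocation (serialize to a string, reject naive datetimes on the way back in) reduces to a small stdlib check. The helper below mirrors the validator's intent under that assumption; note that datetime.fromisoformat only accepts a trailing "Z", as produced in test_write_lmp above, on Python 3.11 and later.

from datetime import datetime

def parse_aware(value: str) -> datetime:
    # Accept only timezone-aware ISO-8601 timestamps.
    dt = datetime.fromisoformat(value)
    if dt.tzinfo is None:
        raise ValueError("Datetime string must include timezone information")
    return dt

assert parse_aware("2024-01-01T00:00:00+00:00").tzinfo is not None

try:
    parse_aware("2024-01-01T00:00:00")  # naive timestamp, rejected
except ValueError as exc:
    print(exc)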