diff --git a/.github/workflows/build_all.yml b/.github/workflows/build_all.yml new file mode 100644 index 0000000..42212d7 --- /dev/null +++ b/.github/workflows/build_all.yml @@ -0,0 +1,80 @@ +name: Build All Platforms + +on: + push: + branches: + - main + pull_request: + +jobs: + macos: + runs-on: macos-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: x86_64-apple-darwin,aarch64-apple-darwin + + - name: Build macOS binaries + working-directory: modules/c-wrapper + run: | + cargo build --release --target x86_64-apple-darwin + cargo build --release --target aarch64-apple-darwin + + - name: Upload macOS artifacts + uses: actions/upload-artifact@v4 + with: + name: macos-binaries + path: modules/c-wrapper/target/*/release/ + + linux: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: x86_64-unknown-linux-gnu,aarch64-unknown-linux-gnu + + - name: Build Linux binaries + working-directory: modules/c-wrapper + run: | + cargo build --release --target x86_64-unknown-linux-gnu + cargo build --release --target aarch64-unknown-linux-gnu + + - name: Upload Linux artifacts + uses: actions/upload-artifact@v4 + with: + name: linux-binaries + path: modules/c-wrapper/target/*/release/ + + windows: + runs-on: windows-latest + steps: + - name: Check out code + uses: actions/checkout@v3 + + - name: Install Rust + uses: actions-rs/toolchain@v1 + with: + toolchain: stable + target: x86_64-pc-windows-gnu,aarch64-pc-windows-gnu + + - name: Build Windows binaries + working-directory: modules/c-wrapper + run: | + cargo build --release --target x86_64-pc-windows-gnu + cargo build --release --target aarch64-pc-windows-gnu + + - name: Upload Windows artifacts + uses: actions/upload-artifact@v4 + with: + name: windows-binaries + path: 
modules/c-wrapper/target/*/release/ diff --git a/.github/workflows/c_wrapper_unit_tests.yml b/.github/workflows/c_wrapper_unit_tests.yml new file mode 100644 index 0000000..d0ca5e3 --- /dev/null +++ b/.github/workflows/c_wrapper_unit_tests.yml @@ -0,0 +1,28 @@ +name: Run C-wrapper unit tests + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + test_core: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - name: Set up Rust + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: stable + override: true + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: '3.11' + + - name: Run Core Unit Tests + run: cd modules/c-wrapper/scripts && sh build-docker.sh diff --git a/.github/workflows/surrealml_core_onnx_test.yml b/.github/workflows/surrealml_core_onnx_test.yml index 03ab696..e95243e 100644 --- a/.github/workflows/surrealml_core_onnx_test.yml +++ b/.github/workflows/surrealml_core_onnx_test.yml @@ -13,37 +13,5 @@ jobs: with: fetch-depth: 0 - - name: Set up Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.11' - - - name: Pre-test Setup - run: | - python3 -m venv venv - source venv/bin/activate - pip install --upgrade pip - pip install -r requirements.txt - - # build the local version of the core module to be loaded into python - echo "Building local version of core module" - - pip install . - export PYTHONPATH="." - - python ./tests/scripts/ci_local_build.py - echo "Local build complete" - - # train the models for the tests - python ./tests/model_builder/onnx_assets.py - deactivate - - name: Run Core Unit Tests - run: cd modules/core && cargo test --features onnx-tests + run: cd modules/core && docker build -t rust-onnx-runtime . 
&& docker run --rm rust-onnx-runtime cargo test --features onnx-tests diff --git a/.github/workflows/surrealml_core_tensorflow_test.yml b/.github/workflows/surrealml_core_tensorflow_test.yml index 355e754..4fa1a02 100644 --- a/.github/workflows/surrealml_core_tensorflow_test.yml +++ b/.github/workflows/surrealml_core_tensorflow_test.yml @@ -13,37 +13,5 @@ jobs: with: fetch-depth: 0 - - name: Set up Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.11' - - - name: Pre-test Setup - run: | - python3 -m venv venv - source venv/bin/activate - pip install --upgrade pip - pip install -r requirements.txt - - # build the local version of the core module to be loaded into python - echo "Building local version of core module" - - pip install . - export PYTHONPATH="." - - python ./tests/scripts/ci_local_build.py - echo "Local build complete" - - # train the models for the tests - python ./tests/model_builder/tensorflow_assets.py - deactivate - - name: Run Core Unit Tests - run: cd modules/core && cargo test --features tensorflow-tests + run: cd modules/core && docker build -t rust-onnx-runtime . 
&& docker run --rm rust-onnx-runtime cargo test --features tensorflow-tests diff --git a/.github/workflows/surrealml_core_test.yml b/.github/workflows/surrealml_core_test.yml index be6e3d9..619bdfc 100644 --- a/.github/workflows/surrealml_core_test.yml +++ b/.github/workflows/surrealml_core_test.yml @@ -13,47 +13,5 @@ jobs: with: fetch-depth: 0 - - name: Set up Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.11' - - - name: Pre-test Setup - run: | - python3 -m venv venv - source venv/bin/activate - pip install --upgrade pip - pip install -r requirements.txt - - # build the local version of the core module to be loaded into python - echo "Building local version of core module" - - pip install . - export PYTHONPATH="." - - python ./tests/scripts/ci_local_build.py - echo "Local build complete" - - # train the models for the tests - python ./tests/model_builder/sklearn_assets.py - deactivate - - - name: Run Python Unit Tests - run: | - source venv/bin/activate - export PYTHONPATH="." - python tests/unit_tests/engine/test_sklearn.py - deactivate - - name: Run Core Unit Tests - run: cd modules/core && cargo test --features sklearn-tests - - - name: Run HTTP Transfer Tests - run: cargo test + run: cd modules/core && docker build -t rust-onnx-runtime . 
&& docker run --rm rust-onnx-runtime cargo test --features sklearn-tests diff --git a/.github/workflows/surrealml_core_torch_test.yml b/.github/workflows/surrealml_core_torch_test.yml index 01b679f..23507f3 100644 --- a/.github/workflows/surrealml_core_torch_test.yml +++ b/.github/workflows/surrealml_core_torch_test.yml @@ -13,46 +13,5 @@ jobs: with: fetch-depth: 0 - - name: Set up Rust - uses: actions-rs/toolchain@v1 - with: - profile: minimal - toolchain: stable - override: true - - - name: Set up Python - uses: actions/setup-python@v2 - with: - python-version: '3.11' - - - name: Pre-test Setup - run: | - python3 -m venv venv - source venv/bin/activate - pip install --upgrade pip - pip install -r requirements.txt - - # build the local version of the core module to be loaded into python - echo "Building local version of core module" - - pip install . - export PYTHONPATH="." - - python ./tests/scripts/ci_local_build.py - echo "Local build complete" - - # train the models for the tests - python ./tests/model_builder/torch_assets.py - deactivate - - - name: Run Python Unit Tests - run: | - source venv/bin/activate - export PYTHONPATH="." - python tests/unit_tests/engine/test_torch.py - python tests/unit_tests/test_rust_adapter.py - python tests/unit_tests/test_surml_file.py - deactivate - - name: Run Core Unit Tests - run: cd modules/core && cargo test --features torch-tests + run: cd modules/core && docker build -t rust-onnx-runtime . 
&& docker run --rm rust-onnx-runtime cargo test --features torch-tests diff --git a/.github/workflows/surrealml_deployment.yml b/.github/workflows/surrealml_deployment.yml index e5ad1fd..3338ccb 100644 --- a/.github/workflows/surrealml_deployment.yml +++ b/.github/workflows/surrealml_deployment.yml @@ -1,187 +1,53 @@ -name: cross-build - -on: - push: - branches: - - main - -env: - CARGO_TERM_COLOR: always - -jobs: - - wait-for-other-workflow: - name: Wait for Other Workflow - runs-on: ubuntu-latest - steps: - - name: Wait for Other Workflow to Complete - run: "echo 'Waiting for other workflow to complete...'" - - build: # Workflow credit to https://github.com/samuelcolvin/rtoml/blob/main/.github/workflows/ci.yml - if: github.ref == 'refs/heads/main' && github.event.pusher - name: > - build ${{ matrix.python-version }} on ${{ matrix.platform || matrix.os }} - (${{ matrix.alt_arch_name || matrix.arch }}) - strategy: - fail-fast: false - matrix: - os: [ubuntu, macos, windows] - python-version: ["cp310", "pp37", "pp38", "pp39"] - arch: [main, alt] - include: - - os: ubuntu - platform: linux - - os: windows - ls: dir - - os: macos - arch: alt - alt_arch_name: "arm64 universal2" - exclude: - - os: macos - python-version: "pp37" - arch: alt - - os: macos - python-version: "pp38" - arch: alt - - os: macos - python-version: "pp39" - arch: alt - runs-on: ${{ format('{0}-latest', matrix.os) }} - steps: - - uses: actions/checkout@v3 - - - name: set up python - uses: actions/setup-python@v4 - with: - python-version: "3.9" - - - name: set up rust - uses: dtolnay/rust-toolchain@stable - with: - toolchain: stable - - - name: install the onnx library - run: | - pip install -r requirements.txt - # pip install requests - # python get_latest_version.py - - - name: Setup Rust cache - uses: Swatinem/rust-cache@v2 - with: - key: ${{ matrix.alt_arch_name }} - - - run: rustup target add aarch64-apple-darwin - if: matrix.os == 'macos' - - - run: rustup toolchain install 
stable-i686-pc-windows-msvc - if: matrix.os == 'windows' - - - run: rustup target add i686-pc-windows-msvc - if: matrix.os == 'windows' - - - name: Get pip cache dir - id: pip-cache - if: matrix.os != 'windows' - run: | - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - - name: Get pip cache dir - id: pip-cache-win - if: matrix.os == 'windows' - run: | - "dir=$(pip cache dir)" >> $env:GITHUB_OUTPUT - - - name: Cache python dependencies - uses: actions/cache@v3 - with: - path: ${{ steps.pip-cache.outputs.dir || steps.pip-cache-win.outputs.dir }} - key: ${{ runner.os }}-pip-${{ matrix.python-version }} - - - name: install python dependencies - run: pip install -U setuptools wheel twine cibuildwheel platformdirs - - - name: Display cibuildwheel cache dir - id: cibuildwheel-cache - run: | - from platformdirs import user_cache_path - import os - with open(os.getenv('GITHUB_OUTPUT'), 'w') as f: - f.write(f"dir={str(user_cache_path(appname='cibuildwheel', appauthor='pypa'))}") - shell: python - - - name: Cache cibuildwheel tools - uses: actions/cache@v3 - with: - path: ${{ steps.cibuildwheel-cache.outputs.dir }} - key: ${{ runner.os }}-cibuildwheel-${{ matrix.python-version }} - - - name: Install LLVM and Clang # required for bindgen to work, see https://github.com/rust-lang/rust-bindgen/issues/1797 - uses: KyleMayes/install-llvm-action@v1 - if: runner.os == 'Windows' - with: - version: "11.0" - directory: ${{ runner.temp }}/llvm - - - name: Set LIBCLANG_PATH - run: echo "LIBCLANG_PATH=$((gcm clang).source -replace "clang.exe")" >> $env:GITHUB_ENV - if: runner.os == 'Windows' - - - name: build_sdist - if: matrix.os == 'ubuntu' && matrix.python-version == 'cp310' - run: | - pip install maturin build - python -m build --sdist -o wheelhouse - - name: build ${{ matrix.platform || matrix.os }} binaries - run: cibuildwheel --output-dir wheelhouse - env: - CIBW_BUILD: "${{ matrix.python-version }}-*" - # rust doesn't seem to be available for musl linux on i686 - CIBW_SKIP: 
"*-musllinux_i686" - # we build for "alt_arch_name" if it exists, else 'auto' - CIBW_ARCHS: ${{ matrix.alt_arch_name || 'auto' }} - CIBW_ENVIRONMENT: 'PATH="$HOME/.cargo/bin:$PATH" CARGO_TERM_COLOR="always"' - CIBW_ENVIRONMENT_WINDOWS: 'PATH="$UserProfile\.cargo\bin;$PATH"' - CIBW_BEFORE_BUILD: rustup show - CIBW_BEFORE_BUILD_LINUX: > - curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain=stable --profile=minimal -y && - rustup show - CIBW_BUILD_VERBOSITY: 1 - - - run: ${{ matrix.ls || 'ls -lh' }} wheelhouse/ - - - uses: actions/upload-artifact@v3 - with: - name: wheels - path: wheelhouse - - release: - if: github.ref == 'refs/heads/main' && github.event.pusher - needs: build - runs-on: ubuntu-latest - steps: - - uses: actions/download-artifact@v2 - with: - name: wheels - path: wheelhouse - - - name: Install twine - run: python -m pip install --upgrade twine - - - name: Create pypirc file - run: | - echo "[distutils]" > ~/.pypirc - echo "index-servers =" >> ~/.pypirc - echo " pypi" >> ~/.pypirc - echo "" >> ~/.pypirc - echo "[pypi]" >> ~/.pypirc - echo "username: __token__" >> ~/.pypirc - echo "password: \${{ secrets.PYPI_TOKEN }}" >> ~/.pypirc - env: - PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} - - - name: Publish to PyPI - run: twine upload wheelhouse/* - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} +# on: +# pull_request: +# types: [opened, reopened, synchronize] + +# jobs: +# release: +# name: Release - ${{ matrix.platform.os-name }} +# strategy: +# matrix: +# platform: +# - os-name: FreeBSD-x86_64 +# runs-on: ubuntu-20.04 +# target: x86_64-unknown-freebsd +# skip_tests: true + +# - os-name: Linux-x86_64 +# runs-on: ubuntu-20.04 +# target: x86_64-unknown-linux-musl + +# - os-name: Linux-aarch64 +# runs-on: ubuntu-20.04 +# target: aarch64-unknown-linux-musl + +# - os-name: Linux-riscv64 +# runs-on: ubuntu-20.04 +# target: riscv64gc-unknown-linux-gnu + +# - os-name: Windows-x86_64 +# runs-on: windows-latest +# target: 
x86_64-pc-windows-msvc + +# - os-name: macOS-x86_64 +# runs-on: macOS-latest +# target: x86_64-apple-darwin + +# # more targets here ... + +# runs-on: ${{ matrix.platform.runs-on }} +# steps: +# - name: Checkout +# uses: actions/checkout@v4 +# - name: Build binary +# uses: houseabsolute/actions-rust-cross@v0 +# with: +# command: ${{ matrix.platform.command }} +# target: ${{ matrix.platform.target }} +# args: "--locked --release" +# strip: true +# - name: Publish artifacts and release +# uses: houseabsolute/actions-rust-release@v0 +# with: +# executable-name: ubi +# target: ${{ matrix.platform.target }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index bc1a8fe..590879d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ surrealml.egg-info/ .vscode/ ./modules/utils/target/ modules/core/target/ +modules/c-wrapper/target/ ./modules/onnx_driver/target/ modules/onnx_driver/target/ surrealdb_build/ @@ -27,3 +28,6 @@ surrealml/rust_surrealml.cpython-310-darwin.so ./modules/pipelines/runners/integrated_training_runner/run_env/ modules/pipelines/runners/integrated_training_runner/run_env/ modules/pipelines/data_access/target/ +clients/python/build-context/ +modules/c-wrapper/build-context/ +*.dylib \ No newline at end of file diff --git a/README.md b/README.md index bf8211e..4da64ba 100644 --- a/README.md +++ b/README.md @@ -14,6 +14,10 @@ SurrealML is a feature that allows you to store trained machine learning models 4. Python Environment Setup: A Python environment with necessary libraries installed, including SurrealML, PyTorch or SKLearn (depending on your model preference). 5. SurrealDB Installation: Ensure you have SurrealDB installed and running on your machine or server +## New Clients + +We are removing the `PyO3` bindings and just using raw C bindings for the `surrealml-core` library. This will simplfy builds and also enable clients in other languges to use the `surrealml-core` library. 
The `c-wrapper` module can be found in the `modules/c-wrapper` directory. The new clients can be found in the `clients` directory. + ## Installation To install SurrealML, make sure you have Python installed. Then, install the `SurrealML` library and either `PyTorch` or diff --git a/clients/python/Dockerfile b/clients/python/Dockerfile new file mode 100644 index 0000000..7a22021 --- /dev/null +++ b/clients/python/Dockerfile @@ -0,0 +1,36 @@ +# Use an official Rust image +FROM rust:1.83-slim + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# Set the working directory +WORKDIR /app + +# Copy the project files into the container +COPY . . + +# Download ONNX Runtime 1.20.0 +RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz \ + && tar -xvf onnxruntime-linux-x64-1.20.0.tgz \ + && mv onnxruntime-linux-x64-1.20.0 /onnxruntime + +# Set the ONNX Runtime library path +# you need these environment variables to be able to link the onnxruntime to the c-lib +ENV ORT_LIB_LOCATION=/onnxruntime/lib +ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + + +# install python +RUN apt-get update && apt-get install -y python3 python3-pip +RUN apt install -y python3.11-venv +RUN python3 -m venv venv +RUN source venv/bin/activate && cd clients/python && pip install . + +CMD ["tail", "-f", "/dev/null"] + diff --git a/clients/python/README.md b/clients/python/README.md new file mode 100644 index 0000000..d8d27e5 --- /dev/null +++ b/clients/python/README.md @@ -0,0 +1,4 @@ + +# SurrealML Python Client + +The SurrealML Python client using the Rust `surrealml` library without any `PyO3` bindings. 
\ No newline at end of file diff --git a/clients/python/assets/linear.surml b/clients/python/assets/linear.surml new file mode 100644 index 0000000..f092b50 Binary files /dev/null and b/clients/python/assets/linear.surml differ diff --git a/clients/python/assets/load.py b/clients/python/assets/load.py new file mode 100644 index 0000000..41b1a78 --- /dev/null +++ b/clients/python/assets/load.py @@ -0,0 +1,9 @@ +from surrealml import SurMlFile, Engine + + +new_file = SurMlFile.load("./linear.surml", engine=Engine.PYTORCH) + +print(new_file.buffered_compute({ + "squarefoot": 1.0, + "num_floors": 2.0 +})) diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml new file mode 100644 index 0000000..de3e49b --- /dev/null +++ b/clients/python/pyproject.toml @@ -0,0 +1,2 @@ +[build-system] +requires = ["setuptools", "wheel", "build"] diff --git a/clients/python/requirements.txt b/clients/python/requirements.txt new file mode 100644 index 0000000..0700417 --- /dev/null +++ b/clients/python/requirements.txt @@ -0,0 +1,2 @@ +onnxruntime==1.17.3 +numpy==1.26.3 diff --git a/clients/python/scripts/build_docker.sh b/clients/python/scripts/build_docker.sh new file mode 100644 index 0000000..f6541b3 --- /dev/null +++ b/clients/python/scripts/build_docker.sh @@ -0,0 +1,39 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +BUILD_DIR="build-context" + +if [ -d "$BUILD_DIR" ]; then + echo "Cleaning up existing build directory..." 
+ rm -rf "$BUILD_DIR" +fi + +mkdir "$BUILD_DIR" +mkdir "$BUILD_DIR"/clients +mkdir "$BUILD_DIR"/clients/python +mkdir "$BUILD_DIR"/modules +mkdir "$BUILD_DIR"/modules/ +mkdir "$BUILD_DIR"/modules/ + +cp -r surrealml "$BUILD_DIR"/clients/python/surrealml +cp -r assets "$BUILD_DIR"/clients/python/assets +cp setup.py "$BUILD_DIR"/clients/python/setup.py +cp pyproject.toml "$BUILD_DIR"/clients/python/pyproject.toml + +cp Dockerfile "$BUILD_DIR"/Dockerfile + +cp -r ../../modules/c-wrapper/ "$BUILD_DIR"/modules/ +cp -r ../../modules/core/ "$BUILD_DIR"/modules/ +rm -rf "$BUILD_DIR"/modules/core/.git +rm -rf "$BUILD_DIR"/modules/c-wrapper/.git +rm -rf "$BUILD_DIR"/modules/core/target/ +rm -rf "$BUILD_DIR"/modules/c-wrapper/target/ +cd "$BUILD_DIR" +docker build --no-cache -t surrealml-python . + +docker run -it surrealml-python /bin/bash \ No newline at end of file diff --git a/clients/python/scripts/build_wheel.sh b/clients/python/scripts/build_wheel.sh new file mode 100644 index 0000000..8f1dc02 --- /dev/null +++ b/clients/python/scripts/build_wheel.sh @@ -0,0 +1,2 @@ + +python -m build \ No newline at end of file diff --git a/clients/python/scripts/download_onnx.sh b/clients/python/scripts/download_onnx.sh new file mode 100644 index 0000000..f1f641a --- /dev/null +++ b/clients/python/scripts/download_onnx.sh @@ -0,0 +1 @@ +#!/usr/bin/env bash diff --git a/clients/python/setup.py b/clients/python/setup.py new file mode 100644 index 0000000..a1f8e4d --- /dev/null +++ b/clients/python/setup.py @@ -0,0 +1,83 @@ +import os +import platform +import shutil +import subprocess +from pathlib import Path + +from setuptools import setup + + +def get_c_lib_name() -> str: + system = platform.system() + if system == "Linux": + return "libc_wrapper.so" + elif system == "Darwin": # macOS + return "libc_wrapper.dylib" + elif system == "Windows": + return "libc_wrapper.dll" + raise ValueError(f"Unsupported system: {system}") + +# define the paths to the C wrapper and root +DIR_PATH = 
Path(__file__).parent +ROOT_PATH = DIR_PATH.joinpath("..").joinpath("..") +C_PATH = ROOT_PATH.joinpath("modules").joinpath("c-wrapper") +BINARY_PATH = C_PATH.joinpath("target").joinpath("release").joinpath(get_c_lib_name()) +BINARY_DIST = DIR_PATH.joinpath("surrealml").joinpath(get_c_lib_name()) + +build_flag = False + +# build the C lib and copy it over to the python lib +if BINARY_DIST.exists() is False: + subprocess.Popen("cargo build --release", cwd=str(C_PATH), shell=True).wait() + shutil.copyfile(BINARY_PATH, BINARY_DIST) + build_flag = True + +setup( + name="surrealml", + version="0.1.0", + description="A machine learning package for interfacing with various frameworks.", + author="Maxwell Flitton", + author_email="maxwellflitton@gmail.com", + url="https://github.com/surrealdb/surrealml", + license="MIT", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + ], + python_requires=">=3.6", + install_requires=[ + "numpy==1.26.3", + ], + extras_require={ + "sklearn": [ + "skl2onnx==1.16.0", + "scikit-learn==1.4.0", + ], + "torch": [ + "torch==2.1.2", + ], + "tensorflow": [ + "tf2onnx==1.16.1", + "tensorflow==2.16.1", + ], + }, + packages=[ + "surrealml", + "surrealml.engine", + "surrealml.model_templates", + "surrealml.model_templates.datasets", + "surrealml.model_templates.sklearn", + "surrealml.model_templates.torch", + "surrealml.model_templates.tensorflow", + ], + package_data={ + "surrealml": ["libc_wrapper.so", "libc_wrapper.dylib", "libc_wrapper.dll"] + }, + include_package_data=True, + zip_safe=False, +) + +# cleanup after install +if build_flag is True: + os.remove(BINARY_DIST) diff --git a/clients/python/surrealml/VERSION.txt b/clients/python/surrealml/VERSION.txt new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/__init__.py b/clients/python/surrealml/__init__.py new file mode 100644 index 0000000..fcd35cc --- /dev/null +++ 
b/clients/python/surrealml/__init__.py @@ -0,0 +1,2 @@ +from surrealml.surml_file import SurMlFile +from surrealml.engine import Engine diff --git a/clients/python/surrealml/c_structs.py b/clients/python/surrealml/c_structs.py new file mode 100644 index 0000000..d8f442a --- /dev/null +++ b/clients/python/surrealml/c_structs.py @@ -0,0 +1,94 @@ +""" +Defines all the C structs that are returned from the C lib. +""" +from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float, c_byte + + +class StringReturn(Structure): + """ + A return type that just returns a string + + Fields: + string: the string that is being returned (only present if successful) + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("string", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + +class EmptyReturn(Structure): + """ + A return type that just returns nothing + + Fields: + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + + +class FileInfo(Structure): + """ + A return type when loading the meta of a surml file. 
+ + Fields: + file_id: a unique identifier for the file in the state of the C lib + name: a name of the model + description: a description of the model + error_message: the error message (only present if error) + is_error: 1 if error, 0 if not + """ + _fields_ = [ + ("file_id", c_char_p), # Corresponds to *mut c_char + ("name", c_char_p), # Corresponds to *mut c_char + ("description", c_char_p), # Corresponds to *mut c_char + ("version", c_char_p), # Corresponds to *mut c_char + ("error_message", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int) # Corresponds to c_int + ] + + +class Vecf32Return(Structure): + """ + A return type when loading the meta of a surml vector. + + Fields: + data: the result of the ML execution + length: the length of the vector + capacity: the capacity of the vector + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("data", POINTER(c_float)), # Pointer to f32 array + ("length", c_size_t), # Length of the array + ("capacity", c_size_t), # Capacity of the array + ("is_error", c_int), # Indicates if it's an error + ("error_message", c_char_p), # Optional error message + ] + + +class VecU8Return(Structure): + """ + A return type returning bytes. 
+ + Fields: + data: bytes + length: the length of the vector + capacity: the capacity of the vector + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("data", POINTER(c_byte)), # Pointer to bytes + ("length", c_size_t), # Length of the array + ("capacity", c_size_t), # Capacity of the array + ("is_error", c_int), # Indicates if it's an error + ("error_message", c_char_p), + ] diff --git a/clients/python/surrealml/engine/__init__.py b/clients/python/surrealml/engine/__init__.py new file mode 100644 index 0000000..05708ae --- /dev/null +++ b/clients/python/surrealml/engine/__init__.py @@ -0,0 +1,24 @@ +from enum import Enum + +from surrealml.engine.sklearn import SklearnOnnxAdapter +from surrealml.engine.torch import TorchOnnxAdapter +from surrealml.engine.tensorflow import TensorflowOnnxAdapter +from surrealml.engine.onnx import OnnxAdapter + + +class Engine(Enum): + """ + The Engine enum is used to specify the engine to use for a given model. + + Attributes: + PYTORCH: The PyTorch engine which will be PyTorch and ONNX. + NATIVE: The native engine which will be native rust and linfa. + SKLEARN: The sklearn engine which will be sklearn and ONNX + TENSOFRLOW: The TensorFlow engine which will be TensorFlow and ONNX + ONNX: The ONNX engine which bypasses the conversion to ONNX. + """ + PYTORCH = "pytorch" + NATIVE = "native" + SKLEARN = "sklearn" + TENSORFLOW = "tensorflow" + ONNX = "onnx" diff --git a/clients/python/surrealml/engine/onnx.py b/clients/python/surrealml/engine/onnx.py new file mode 100644 index 0000000..d904672 --- /dev/null +++ b/clients/python/surrealml/engine/onnx.py @@ -0,0 +1,26 @@ +""" +This file defines the adapter for the ONNX file format. This adapter does not convert anything as the input +model is already in the ONNX format. It simply saves the model to a file. 
However, I have added this adapter +to keep the same structure as the other adapters for different engines (maxwell flitton). +""" +from surrealml.engine.utils import create_file_cache_path + + +class OnnxAdapter: + + @staticmethod + def save_model_to_onnx(model, inputs) -> str: + """ + Saves a model to an onnx file. + + :param model: the raw ONNX model to directly save + :param inputs: the inputs to the model needed to trace the model + :return: the path to the cache created with a unique id to prevent collisions. + """ + file_path = create_file_cache_path() + + with open(file_path, "wb") as f: + f.write(model.SerializeToString()) + + return file_path + diff --git a/clients/python/surrealml/engine/sklearn.py b/clients/python/surrealml/engine/sklearn.py new file mode 100644 index 0000000..089bb57 --- /dev/null +++ b/clients/python/surrealml/engine/sklearn.py @@ -0,0 +1,42 @@ +""" +This file defines the adapter that converts an sklearn model to an onnx model and saves the onnx model to a file. +""" +try: + import skl2onnx +except ImportError: + skl2onnx = None + +from surrealml.engine.utils import create_file_cache_path + + +class SklearnOnnxAdapter: + """ + Converts and saves sklearn models to onnx format. + """ + + @staticmethod + def check_dependency() -> None: + """ + Checks if the sklearn dependency is installed raising an error if not. + Please call this function when performing any sklearn related operations. + """ + if skl2onnx is None: + raise ImportError("sklearn feature needs to be installed to use sklearn features") + + @staticmethod + def save_model_to_onnx(model, inputs) -> str: + """ + Saves a sklearn model to an onnx file. + + :param model: the sklearn model to convert. + :param inputs: the inputs to the model needed to trace the model + :return: the path to the cache created with a unique id to prevent collisions. 
+ """ + SklearnOnnxAdapter.check_dependency() + file_path = create_file_cache_path() + onnx = skl2onnx.to_onnx(model, inputs) + + with open(file_path, "wb") as f: + f.write(onnx.SerializeToString()) + + return file_path diff --git a/clients/python/surrealml/engine/tensorflow.py b/clients/python/surrealml/engine/tensorflow.py new file mode 100644 index 0000000..05a2937 --- /dev/null +++ b/clients/python/surrealml/engine/tensorflow.py @@ -0,0 +1,45 @@ +import os +import shutil +try: + import tf2onnx + import tensorflow as tf +except ImportError: + tf2onnx = None + tf = None + +from surrealml.engine.utils import TensorflowCache + + +class TensorflowOnnxAdapter: + + @staticmethod + def check_dependency() -> None: + """ + Checks if the tensorflow dependency is installed raising an error if not. + Please call this function when performing any tensorflow related operations. + """ + if tf2onnx is None or tf is None: + raise ImportError("tensorflow feature needs to be installed to use tensorflow features") + + @staticmethod + def save_model_to_onnx(model, inputs) -> str: + """ + Saves a tensorflow model to an onnx file. + + :param model: the tensorflow model to convert. + :param inputs: the inputs to the model needed to trace the model + :return: the path to the cache created with a unique id to prevent collisions. 
+ """ + TensorflowOnnxAdapter.check_dependency() + cache = TensorflowCache() + + model_file_path = cache.new_cache_path + onnx_file_path = cache.new_cache_path + + tf.saved_model.save(model, model_file_path) + + os.system( + f"python -m tf2onnx.convert --saved-model {model_file_path} --output {onnx_file_path}" + ) + shutil.rmtree(model_file_path) + return onnx_file_path diff --git a/clients/python/surrealml/engine/torch.py b/clients/python/surrealml/engine/torch.py new file mode 100644 index 0000000..38e097f --- /dev/null +++ b/clients/python/surrealml/engine/torch.py @@ -0,0 +1,35 @@ +try: + import torch +except ImportError: + torch = None + +from surrealml.engine.utils import create_file_cache_path + + +class TorchOnnxAdapter: + + @staticmethod + def check_dependency() -> None: + """ + Checks if the sklearn dependency is installed raising an error if not. + Please call this function when performing any sklearn related operations. + """ + if torch is None: + raise ImportError("torch feature needs to be installed to use torch features") + + @staticmethod + def save_model_to_onnx(model, inputs) -> str: + """ + Saves a torch model to an onnx file. + + :param model: the torch model to convert. + :param inputs: the inputs to the model needed to trace the model + :return: the path to the cache created with a unique id to prevent collisions. + """ + # the dynamic import it to prevent the torch dependency from being required for the whole package. + file_path = create_file_cache_path() + # below is to satisfy type checkers + if torch is not None: + traced_script_module = torch.jit.trace(model, inputs) + torch.onnx.export(traced_script_module, inputs, file_path) + return file_path diff --git a/clients/python/surrealml/engine/utils.py b/clients/python/surrealml/engine/utils.py new file mode 100644 index 0000000..dc413aa --- /dev/null +++ b/clients/python/surrealml/engine/utils.py @@ -0,0 +1,35 @@ +""" +This file contains utility functions for the engine. 
+""" +import os +import uuid + + +def create_file_cache_path(cache_folder: str = ".surmlcache") -> os.path: + """ + Creates a file cache path for the model (creating the file cache if not there). + + :return: the path to the cache created with a unique id to prevent collisions. + """ + if not os.path.exists(cache_folder): + os.makedirs(cache_folder) + unique_id = str(uuid.uuid4()) + file_name = f"{unique_id}.surml" + return os.path.join(cache_folder, file_name) + + +class TensorflowCache: + """ + A class to create a cache for tensorflow models. + + Attributes: + cache_path: The path to the cache created with a unique id to prevent collisions. + """ + def __init__(self) -> None: + create_file_cache_path() + self.cache_path = os.path.join(".surmlcache", "tensorflow") + create_file_cache_path(cache_folder=self.cache_path) + + @property + def new_cache_path(self) -> str: + return str(os.path.join(self.cache_path, str(uuid.uuid4()))) diff --git a/clients/python/surrealml/loader.py b/clients/python/surrealml/loader.py new file mode 100644 index 0000000..1852616 --- /dev/null +++ b/clients/python/surrealml/loader.py @@ -0,0 +1,112 @@ +""" +The loader for the dynamic C lib written in Rust. +""" +import ctypes +import platform +from pathlib import Path + +from surrealml.c_structs import EmptyReturn, StringReturn, Vecf32Return, FileInfo, VecU8Return + + +class Singleton(type): + """ + Ensures that the loader only loads once throughout the program's lifetime + """ + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL: + """ + Load the correct shared library based on the operating system. + + Args: + lib_name (str): The base name of the library without extension (e.g., "libc_wrapper"). + + Returns: + ctypes.CDLL: The loaded shared library. 
+ """ + current_dir = Path(__file__).parent + system_name = platform.system() + + if system_name == "Windows": + lib_path = current_dir.joinpath(f"{lib_name}.dll") + elif system_name == "Darwin": # macOS + lib_path = current_dir.joinpath(f"{lib_name}.dylib") + elif system_name == "Linux": + lib_path = current_dir.joinpath(f"{lib_name}.so") + else: + raise OSError(f"Unsupported operating system: {system_name}") + + if not lib_path.exists(): + raise FileNotFoundError(f"Shared library not found at: {lib_path}") + + return ctypes.CDLL(str(lib_path)) + + +class LibLoader(metaclass=Singleton): + + def __init__(self, lib_name: str = "libc_wrapper") -> None: + """ + The constructor for the LibLoader class. + + args: + lib_name (str): The base name of the library without extension (e.g., "libc_wrapper"). + """ + self.lib = load_library(lib_name=lib_name) + functions = [ + self.lib.add_name, + self.lib.add_description, + self.lib.add_version, + self.lib.add_column, + self.lib.add_author, + self.lib.add_origin, + self.lib.add_engine, + ] + for i in functions: + i.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + i.restype = EmptyReturn + self.lib.load_model.restype = FileInfo + self.lib.load_model.argtypes = [ctypes.c_char_p] + self.lib.load_cached_raw_model.restype = StringReturn + self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p] + self.lib.to_bytes.argtypes = [ctypes.c_char_p] + self.lib.to_bytes.restype = VecU8Return + self.lib.save_model.restype = EmptyReturn + self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + self.lib.upload_model.argtypes = [ + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_size_t, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_char_p, + ctypes.c_char_p, + ] + self.lib.upload_model.restype = EmptyReturn + + # define the compute functions + self.lib.raw_compute.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t] + self.lib.raw_compute.restype = Vecf32Return + self.lib.buffered_compute.argtypes = [ + 
ctypes.c_char_p, # file_id_ptr -> *const c_char + ctypes.POINTER(ctypes.c_float), # data_ptr -> *const c_float + ctypes.c_size_t, # data_length -> usize + ctypes.POINTER(ctypes.c_char_p), # strings -> *const *const c_char + ctypes.c_int # string_count -> c_int + ] + self.lib.buffered_compute.restype = Vecf32Return + + # Define free alloc functions + self.lib.free_string_return.argtypes = [StringReturn] + self.lib.free_empty_return.argtypes = [EmptyReturn] + self.lib.free_vec_u8.argtypes = [VecU8Return] + self.lib.free_vecf32_return.argtypes = [Vecf32Return] + self.lib.free_file_info.argtypes = [FileInfo] + + + diff --git a/clients/python/surrealml/model_templates/__init__.py b/clients/python/surrealml/model_templates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/datasets/__init__.py b/clients/python/surrealml/model_templates/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/datasets/house_linear.py b/clients/python/surrealml/model_templates/datasets/house_linear.py new file mode 100644 index 0000000..e592c28 --- /dev/null +++ b/clients/python/surrealml/model_templates/datasets/house_linear.py @@ -0,0 +1,41 @@ +import numpy as np + + +raw_squarefoot = np.array([1000, 1200, 1500, 1800, 2000, 2200, 2500, 2800, 3000, 3200], dtype=np.float32) +raw_num_floors = np.array([1, 1, 1.5, 1.5, 2, 2, 2.5, 2.5, 3, 3], dtype=np.float32) +raw_house_price = np.array([200000, 230000, 280000, 320000, 350000, 380000, 420000, 470000, 500000, 520000], + dtype=np.float32) +squarefoot = (raw_squarefoot - raw_squarefoot.mean()) / raw_squarefoot.std() +num_floors = (raw_num_floors - raw_num_floors.mean()) / raw_num_floors.std() +house_price = (raw_house_price - raw_house_price.mean()) / raw_house_price.std() +inputs = np.column_stack((squarefoot, num_floors)) + + +HOUSE_LINEAR = { + "inputs": inputs, + "outputs": house_price, + + "squarefoot": squarefoot, + 
"num_floors": num_floors, + "input order": ["squarefoot", "num_floors"], + "raw_inputs": { + "squarefoot": raw_squarefoot, + "num_floors": raw_num_floors, + }, + "normalised_inputs": { + "squarefoot": squarefoot, + "num_floors": num_floors, + }, + "normalisers": { + "squarefoot": { + "type": "z_score", + "mean": squarefoot.mean(), + "std": squarefoot.std() + }, + "num_floors": { + "type": "z_score", + "mean": num_floors.mean(), + "std": num_floors.std() + } + }, +} diff --git a/clients/python/surrealml/model_templates/onnx/__init__.py b/clients/python/surrealml/model_templates/onnx/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/onnx/onnx_linear.py b/clients/python/surrealml/model_templates/onnx/onnx_linear.py new file mode 100644 index 0000000..5c10edb --- /dev/null +++ b/clients/python/surrealml/model_templates/onnx/onnx_linear.py @@ -0,0 +1,44 @@ +""" +Trains a linear regression model using sklearn but keeping the ONNX format for the raw onnx support. +""" +from sklearn.linear_model import LinearRegression + +from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR + + +def train_model(): + """ + Trains a linear regression model using sklearn and returns the raw ONNX format. + This is a basic model that can be used for testing. + """ + import skl2onnx + model = LinearRegression() + model.fit(HOUSE_LINEAR["inputs"], HOUSE_LINEAR["outputs"]) + return skl2onnx.to_onnx(model, HOUSE_LINEAR["inputs"]) + + +def export_model_onnx(model): + """ + Exports the model to ONNX format. + + :param model: the model to export. + :return: the path to the exported model. + """ + return model + + +def export_model_surml(model): + """ + Exports the model to SURML format. + + :param model: the model to export. + :return: the path to the exported model. 
+ """ + from surrealml import SurMlFile, Engine + file = SurMlFile(model=model, name="linear", inputs=HOUSE_LINEAR["inputs"], engine=Engine.ONNX) + file.add_column("squarefoot") + file.add_column("num_floors") + file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) + file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) + file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) + return file diff --git a/clients/python/surrealml/model_templates/sklearn/__init__.py b/clients/python/surrealml/model_templates/sklearn/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/sklearn/sklearn_linear.py b/clients/python/surrealml/model_templates/sklearn/sklearn_linear.py new file mode 100644 index 0000000..e5f6896 --- /dev/null +++ b/clients/python/surrealml/model_templates/sklearn/sklearn_linear.py @@ -0,0 +1,43 @@ +""" +Trains a linear regression model using sklearn. This is a basic model that can be used for testing. +""" +from sklearn.linear_model import LinearRegression + +from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR + + +def train_model(): + """ + Trains a linear regression model using sklearn. This is a basic model that can be used for testing. + """ + model = LinearRegression() + model.fit(HOUSE_LINEAR["inputs"], HOUSE_LINEAR["outputs"]) + return model + + +def export_model_onnx(model): + """ + Exports the model to ONNX format. + + :param model: the model to export. + :return: the path to the exported model. + """ + import skl2onnx + return skl2onnx.to_onnx(model, HOUSE_LINEAR["inputs"]) + + +def export_model_surml(model): + """ + Exports the model to SURML format. + + :param model: the model to export. + :return: the path to the exported model. 
+ """ + from surrealml import SurMlFile, Engine + file = SurMlFile(model=model, name="linear", inputs=HOUSE_LINEAR["inputs"], engine=Engine.SKLEARN) + file.add_column("squarefoot") + file.add_column("num_floors") + file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) + file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) + file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) + return file diff --git a/clients/python/surrealml/model_templates/tensorflow/__init__.py b/clients/python/surrealml/model_templates/tensorflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/tensorflow/tensorflow_linear.py b/clients/python/surrealml/model_templates/tensorflow/tensorflow_linear.py new file mode 100644 index 0000000..d2ab2ce --- /dev/null +++ b/clients/python/surrealml/model_templates/tensorflow/tensorflow_linear.py @@ -0,0 +1,96 @@ +""" +Trains a linear regression model in TensorFlow. Should be used for testing certain processes +for linear regression and TensorFlow. 
+""" +import os +import shutil + +import tensorflow as tf + +from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR + + +class LinearModel(tf.Module): + def __init__(self, W, b): + super(LinearModel, self).__init__() + self.W = tf.Variable(W, dtype=tf.float32) + self.b = tf.Variable(b, dtype=tf.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)]) + def predict(self, x): + return tf.matmul(x, self.W) + self.b + + +def train_model(): + # Convert inputs and outputs to TensorFlow tensors + inputs = tf.constant(HOUSE_LINEAR["inputs"], dtype=tf.float32) + outputs = tf.constant(HOUSE_LINEAR["outputs"], dtype=tf.float32) + + # Model parameters + W = tf.Variable(tf.random.normal([2, 1]), name='weights') # Adjusted for two input features + b = tf.Variable(tf.zeros([1]), name='bias') + + # Training parameters + learning_rate = 0.01 + epochs = 100 + + # Training loop + for epoch in range(epochs): + with tf.GradientTape() as tape: + y_pred = tf.matmul(inputs, W) + b # Adjusted for matrix multiplication + loss = tf.reduce_mean(tf.square(y_pred - outputs)) + + gradients = tape.gradient(loss, [W, b]) + W.assign_sub(learning_rate * gradients[0]) + b.assign_sub(learning_rate * gradients[1]) + + if epoch % 10 == 0: # Print loss every 10 epochs + print(f"Epoch {epoch}: Loss = {loss.numpy()}") + + # Final parameters after training + final_W = W.numpy() + final_b = b.numpy() + + print(f"Trained W: {final_W}, Trained b: {final_b}") + return LinearModel(final_W, final_b) + + +def export_model_tf(model): + """ + Exports the model to TensorFlow SavedModel format. + """ + tf.saved_model.save(model, "linear_regression_model_tf") + return 'linear_regression_model_tf' + + +def export_model_onnx(model): + """ + Exports the model to ONNX format. + + :return: the path to the exported model. 
+ """ + export_model_tf(model) + os.system("python -m tf2onnx.convert --saved-model linear_regression_model_tf --output model.onnx") + + with open("model.onnx", "rb") as f: + onnx_model = f.read() + shutil.rmtree("linear_regression_model_tf") + os.remove("model.onnx") + return onnx_model + + +def export_model_surml(model): + """ + Exports the model to SURML format. + + :param model: the model to export. + :return: the path to the exported model. + """ + from surrealml import SurMlFile, Engine + file = SurMlFile(model=model, name="linear", inputs=HOUSE_LINEAR["inputs"], engine=Engine.TENSORFLOW) + file.add_column("squarefoot") + file.add_column("num_floors") + file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) + file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) + file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) + return file diff --git a/clients/python/surrealml/model_templates/torch/__init__.py b/clients/python/surrealml/model_templates/torch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/clients/python/surrealml/model_templates/torch/torch_linear.py b/clients/python/surrealml/model_templates/torch/torch_linear.py new file mode 100644 index 0000000..e38c514 --- /dev/null +++ b/clients/python/surrealml/model_templates/torch/torch_linear.py @@ -0,0 +1,90 @@ +""" +Trains a linear regression model in torch. Should be used for testing certain processes +for linear regression and torch. 
+""" +import torch +import torch.nn as nn +import torch.optim as optim + +from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR + + +class LinearRegressionModel(nn.Module): + def __init__(self): + super(LinearRegressionModel, self).__init__() + self.linear = nn.Linear(2, 1) # 2 input features, 1 output + + def forward(self, x): + return self.linear(x) + + +def train_model(): + """ + Trains a linear regression model in torch. Should be used for testing certain processes. + """ + tensor = [ + torch.from_numpy(HOUSE_LINEAR["squarefoot"]), + torch.from_numpy(HOUSE_LINEAR["num_floors"]) + ] + X = torch.stack(tensor, dim=1) + + # Initialize the model + model = LinearRegressionModel() + + # Define the loss function and optimizer + criterion = nn.MSELoss() + optimizer = optim.SGD(model.parameters(), lr=0.01) + + num_epochs = 1000 + for epoch in range(num_epochs): + # Forward pass + y_pred = model(X) + + # Compute the loss + loss = criterion(y_pred.squeeze(), torch.from_numpy(HOUSE_LINEAR["outputs"])) + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + + test_squarefoot = torch.tensor([2800, 3200], dtype=torch.float32) + test_num_floors = torch.tensor([2.5, 3], dtype=torch.float32) + x = torch.stack([test_squarefoot, test_num_floors], dim=1) + return model, x + + +def export_model_onnx(model): + """ + Exports the model to ONNX format. + """ + tensor = [ + torch.from_numpy(HOUSE_LINEAR["squarefoot"]), + torch.from_numpy(HOUSE_LINEAR["num_floors"]) + ] + inputs = torch.stack(tensor, dim=1) + return torch.jit.trace(model, inputs) + + +def export_model_surml(model): + """ + Exports the model to SURML format. + + :param model: the model to export. + :return: the path to the exported model. 
+ """ + from surrealml import SurMlFile, Engine + + tensor = [ + torch.from_numpy(HOUSE_LINEAR["squarefoot"]), + torch.from_numpy(HOUSE_LINEAR["num_floors"]) + ] + inputs = torch.stack(tensor, dim=1) + + file = SurMlFile(model=model, name="linear", inputs=inputs[:1], engine=Engine.PYTORCH) + file.add_column("squarefoot") + file.add_column("num_floors") + file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) + file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) + file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) + return file diff --git a/clients/python/surrealml/rust_adapter.py b/clients/python/surrealml/rust_adapter.py new file mode 100644 index 0000000..e570bfc --- /dev/null +++ b/clients/python/surrealml/rust_adapter.py @@ -0,0 +1,358 @@ +""" +The adapter to interact with the Rust module compiled to a C dynamic library +""" +import ctypes +import platform +import warnings +from pathlib import Path +from typing import List, Tuple +from typing import Optional + +from surrealml.c_structs import EmptyReturn, StringReturn, Vecf32Return, FileInfo, VecU8Return +from surrealml.engine import Engine +from surrealml.loader import LibLoader + + +def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL: + """ + Load the correct shared library based on the operating system. + + Args: + lib_name (str): The base name of the library without extension (e.g., "libc_wrapper"). + + Returns: + ctypes.CDLL: The loaded shared library. 
+ """ + current_dir = Path(__file__).parent + system_name = platform.system() + + if system_name == "Windows": + lib_path = current_dir.joinpath(f"{lib_name}.dll") + elif system_name == "Darwin": # macOS + lib_path = current_dir.joinpath(f"{lib_name}.dylib") + elif system_name == "Linux": + lib_path = current_dir.joinpath(f"{lib_name}.so") + else: + raise OSError(f"Unsupported operating system: {system_name}") + + if not lib_path.exists(): + raise FileNotFoundError(f"Shared library not found at: {lib_path}") + + return ctypes.CDLL(str(lib_path)) + + +class RustAdapter: + + def __init__(self, file_id: str, engine: Engine) -> None: + self.file_id: str = file_id + self.engine: Engine = engine + self.loader = LibLoader() + + @staticmethod + def pass_raw_model_into_rust(file_path: str) -> str: + """ + Points to a raw ONNX file and passes it into the rust library so it can be loaded + and tagged with a unique id so the Rust library can reference this model again + from within the rust library. + + :param file_path: the path to the raw ONNX file. + + :return: the unique id of the model. + """ + c_path = file_path.encode("utf-8") + loader = LibLoader() + outcome: StringReturn = loader.lib.load_cached_raw_model(c_path) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + file_path = outcome.string.decode("utf-8") + loader.lib.free_string_return(outcome) + return file_path + + def add_column(self, name: str) -> None: + """ + Adds a column to the model to the metadata (this needs to be called in order of the columns). + + :param name: the name of the column. 
+ :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_column( + self.file_id.encode("utf-8"), + name.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_output(self, output_name: str, normaliser_type: str, one: float, two: float) -> None: + """ + Adds an output to the model to the metadata. + :param output_name: the name of the output. + :param normaliser_type: the type of normaliser to use. + :param one: the first parameter of the normaliser. + :param two: the second parameter of the normaliser. + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_output( + self.file_id.encode("utf-8"), + output_name.encode("utf-8"), + normaliser_type.encode("utf-8"), + str(one).encode("utf-8"), + str(two).encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_description(self, description: str) -> None: + """ + Adds a description to the model to the metadata. + + :param description: the description of the model. + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_description( + self.file_id.encode("utf-8"), + description.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_version(self, version: str) -> None: + """ + Adds a version to the model to the metadata. + + :param version: the version of the model. + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_version( + self.file_id.encode("utf-8"), + version.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_name(self, name: str) -> None: + """ + Adds a name to the model to the metadata. + + :param name: the name of the model.
+ :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_name( + self.file_id.encode("utf-8"), + name.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_normaliser(self, column_name, normaliser_type, one, two) -> None: + """ + Adds a normaliser to the model to the metadata for a column. + + :param column_name: the name of the column (column already needs to be in the metadata to create mapping) + :param normaliser_type: the type of normaliser to use. + :param one: the first parameter of the normaliser. + :param two: the second parameter of the normaliser. + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_normaliser( + self.file_id.encode("utf-8"), + column_name.encode("utf-8"), + normaliser_type.encode("utf-8"), + str(one).encode("utf-8"), + str(two).encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def add_author(self, author: str) -> None: + """ + Adds an author to the model to the metadata. + + :param author: the author of the model. + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_author( + self.file_id.encode("utf-8"), + author.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def save(self, path: str, name: Optional[str]) -> None: + """ + Saves the model to a file. + + :param path: the path to save the model to. + :param name: the name of the model. 
+ + :return: None + """ + outcome: EmptyReturn = self.loader.lib.add_engine( + self.file_id.encode("utf-8"), + self.engine.value.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + outcome: EmptyReturn = self.loader.lib.add_origin( + self.file_id.encode("utf-8"), + "local".encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + if name is not None: + outcome: EmptyReturn = self.loader.lib.add_name( + self.file_id.encode("utf-8"), + name.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + else: + warnings.warn( + "You are saving a model without a name, you will not be able to upload this model to the database" + ) + outcome: EmptyReturn = self.loader.lib.save_model( + path.encode("utf-8"), + self.file_id.encode("utf-8") + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + self.loader.lib.free_empty_return(outcome) + + def to_bytes(self) -> bytes: + """ + Converts the model to bytes. + + :return: the model as bytes. + """ + outcome: VecU8Return = self.loader.lib.to_bytes( + self.file_id.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + byte_vec = outcome.data + self.loader.lib.free_vec_u8(outcome) + return byte_vec + + @staticmethod + def load(path) -> Tuple[str, str, str, str]: + """ + Loads a model from a file. + + :param path: the path to load the model from. + :return: the id of the model being loaded. 
+ """ + loader = LibLoader() + outcome: FileInfo = loader.lib.load_model( + path.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + package = ( + outcome.file_id.decode("utf-8"), + outcome.name.decode("utf-8"), + outcome.description.decode("utf-8"), + outcome.version.decode("utf-8"), + ) + loader.lib.free_file_info(outcome) + return package + + @staticmethod + def upload( + path: str, + url: str, + chunk_size: int, + namespace: str, + database: str, + username: Optional[str] = None, + password: Optional[str] = None + ) -> None: + """ + Uploads a model to a remote server. + + :param path: the path to load the model from. + :param url: the url of the remote server. + :param chunk_size: the size of each chunk to upload. + :param namespace: the namespace of the remote server. + :param database: the database of the remote server. + :param username: the username of the remote server. + :param password: the password of the remote server. + + :return: None + """ + loader: EmptyReturn = LibLoader() + outcome = loader.lib.upload_model( + path.encode("utf-8"), + url.encode("utf-8"), + chunk_size, + namespace.encode("utf-8"), + database.encode("utf-8"), + username.encode("utf-8"), + password.encode("utf-8"), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + loader.lib.free_empty_return(outcome) + + def raw_compute(self, input_vector, dims=None) -> List[float]: + """ + Calculates an output from the model given an input vector. + + :param input_vector: a 1D vector of inputs to the model. + :param dims: the dimensions of the input vector to be sliced into + :return: the output of the model. 
+ """ + array_type = ctypes.c_float * len(input_vector) + input_data = array_type(*input_vector) + outcome: Vecf32Return = self.loader.lib.raw_compute( + self.file_id.encode("utf-8"), + input_data, + len(input_data), + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + package = [outcome.data[i] for i in range(outcome.length)] + self.loader.lib.free_vecf32_return(outcome) + return package + + def buffered_compute(self, value_map: dict) -> List[float]: + """ + Calculates an output from the model given a value map. + + :param value_map: a dictionary of inputs to the model with the column names as keys and floats as values. + :return: the output of the model. + """ + string_buffer = [] + data_buffer = [] + for key, value in value_map.items(): + string_buffer.append(key.encode('utf-8')) + data_buffer.append(value) + + # Prepare input data as a ctypes array + array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size + input_data = array_type(*data_buffer) # Instantiate the array with the list elements + + # prepare the input strings + string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer) + string_count = len(string_buffer) + + outcome = self.loader.lib.buffered_compute( + self.file_id.encode("utf-8"), + input_data, + len(input_data), + string_array, + string_count + ) + if outcome.is_error == 1: + raise RuntimeError(outcome.error_message.decode("utf-8")) + return_data = [outcome.data[i] for i in range(outcome.length)] + self.loader.lib.free_vecf32_return(outcome) + return return_data diff --git a/clients/python/surrealml/surml_file.py b/clients/python/surrealml/surml_file.py new file mode 100644 index 0000000..fe89676 --- /dev/null +++ b/clients/python/surrealml/surml_file.py @@ -0,0 +1,224 @@ +""" +Defines the SurMlFile class which is used to save/load models and perform computations based on those models. 
+""" +from typing import Optional + +from surrealml.engine import Engine, SklearnOnnxAdapter, TorchOnnxAdapter, TensorflowOnnxAdapter, OnnxAdapter +from surrealml.rust_adapter import RustAdapter + + +class SurMlFile: + + def __init__(self, model=None, name=None, inputs=None, engine=None): + """ + The constructor for the SurMlFile class. + + :param model: the model to be saved. + :param name: the name of the model. + :param inputs: the inputs to the model needed to trace the model so the model can be saved. + :param sklearn: whether the model is an sklearn model or not. + """ + self.model = model + self.name = name + self.inputs = inputs + self.engine = engine + self.file_id = self._cache_model() + self.rust_adapter = RustAdapter(self.file_id, self.engine) + # below is optional metadata that can be added to the model through functions of the SurMlFile class + self.description = None + self.version = None + self.author = None + + def _cache_model(self) -> Optional[str]: + """ + Caches a model, so it can be loaded as raw bytes to be fused with the header. + + :return: the file id of the model so it can be retrieved from the cache. 
+ """ + # This is triggered when the model is loaded from a file as we are not passing in a model + if self.model is None and self.name is None and self.inputs is None and self.engine is None: + return None + + if self.engine == Engine.SKLEARN: + raw_file_path: str = SklearnOnnxAdapter.save_model_to_onnx( + model=self.model, + inputs=self.inputs + ) + elif self.engine == Engine.PYTORCH: + raw_file_path: str = TorchOnnxAdapter.save_model_to_onnx( + model=self.model, + inputs=self.inputs + ) + elif self.engine == Engine.TENSORFLOW: + raw_file_path: str = TensorflowOnnxAdapter.save_model_to_onnx( + model=self.model, + inputs=self.inputs + ) + # Below doesn't really convert to ONNX, but I want to keep the same structure as the other engines + # (maxwell flitton) + elif self.engine == Engine.ONNX: + raw_file_path: str = OnnxAdapter.save_model_to_onnx( + model=self.model, + inputs=self.inputs + ) + else: + raise ValueError(f"Engine {self.engine} not supported") + return RustAdapter.pass_raw_model_into_rust(raw_file_path) + + def add_column(self, name): + """ + Adds a column to the model to the metadata (this needs to be called in order of the columns). + + :param name: the name of the column. + :return: None + """ + self.rust_adapter.add_column(name=name) + + def add_output(self, output_name, normaliser_type, one, two): + """ + Adds an output to the model to the metadata. + :param output_name: the name of the output. + :param normaliser_type: the type of normaliser to use. + :param one: the first parameter of the normaliser. + :param two: the second parameter of the normaliser. + :return: None + """ + self.rust_adapter.add_output(output_name, normaliser_type, one, two) + + def add_description(self, description: str) -> None: + """ + Adds a description to the model to the metadata. + + :param description: the description of the model. 
+ :return: None + """ + self.description = description + self.rust_adapter.add_description(description) + + def add_version(self, version: str) -> None: + """ + Adds a version to the model to the metadata. + + :param version: the version of the model. + :return: None + """ + self.version = version + self.rust_adapter.add_version(version) + + def add_name(self, name: str) -> None: + """ + Adds a name to th model to the metadata. + + :param name: the name of the model. + :return: None + """ + self.name = name + self.rust_adapter.add_name(name) + + def add_normaliser(self, column_name, normaliser_type, one, two): + """ + Adds a normaliser to the model to the metadata for a column. + + :param column_name: the name of the column (column already needs to be in the metadata to create mapping) + :param normaliser_type: the type of normaliser to use. + :param one: the first parameter of the normaliser. + :param two: the second parameter of the normaliser. + :return: None + """ + self.rust_adapter.add_normaliser(column_name, normaliser_type, one, two) + + def add_author(self, author): + """ + Adds an author to the model to the metadata. + + :param author: the author of the model. + :return: None + """ + self.rust_adapter.add_author(author) + + def save(self, path): + """ + Saves the model to a file. + + :param path: the path to save the model to. + :return: None + """ + # right now the only engine is pytorch so we can hardcode it but when we add more engines we will need to + # add a parameter to the save function to specify the engine + self.rust_adapter.save(path=path, name=self.name) + + def to_bytes(self): + """ + Converts the model to bytes. + + :return: the model as bytes. + """ + return self.rust_adapter.to_bytes() + + @staticmethod + def load(path, engine: Engine): + """ + Loads a model from a file so compute operations can be done. + + :param path: the path to load the model from. + :param engine: the engine to use to load the model. 
+ + :return: The SurMlFile with loaded model and engine definition + """ + self = SurMlFile() + self.file_id, self.name, self.description, self.version = self.rust_adapter.load(path) + self.engine = engine + self.rust_adapter = RustAdapter(self.file_id, self.engine) + return self + + @staticmethod + def upload( + path: str, + url: str, + chunk_size: int, + namespace: str, + database: str, + username: Optional[str] = None, + password: Optional[str] = None + ) -> None: + """ + Uploads a model to a remote server. + + :param path: the path to load the model from. + :param url: the url of the remote server. + :param chunk_size: the size of each chunk to upload. + :param namespace: the namespace of the remote server. + :param database: the database of the remote server. + :param username: the username of the remote server. + :param password: the password of the remote server. + + :return: None + """ + RustAdapter.upload( + path, + url, + chunk_size, + namespace, + database, + username, + password + ) + + def raw_compute(self, input_vector, dims=None): + """ + Calculates an output from the model given an input vector. + + :param input_vector: a 1D vector of inputs to the model. + :param dims: the dimensions of the input vector to be sliced into + :return: the output of the model. + """ + return self.rust_adapter.raw_compute(input_vector, dims) + + def buffered_compute(self, value_map): + """ + Calculates an output from the model given a value map. + + :param value_map: a dictionary of inputs to the model with the column names as keys and floats as values. + :return: the output of the model. 
+ """ + return self.rust_adapter.buffered_compute(value_map) diff --git a/install/install_ml_c_lib.sh b/install/install_ml_c_lib.sh new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/.gitignore b/modules/c-wrapper/.gitignore new file mode 100644 index 0000000..e0d4e69 --- /dev/null +++ b/modules/c-wrapper/.gitignore @@ -0,0 +1 @@ +onnx_lib/ diff --git a/modules/c-wrapper/Cargo.toml b/modules/c-wrapper/Cargo.toml new file mode 100644 index 0000000..cfe851c --- /dev/null +++ b/modules/c-wrapper/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "c-wrapper" +version = "0.1.0" +edition = "2021" + +[dependencies] +surrealml-core = { path = "../core", features = ["dynamic"] } +uuid = { version = "1.11.1", features = ["v4"] } +ndarray = "0.16.1" + +# for the uploading the model to the server +tokio = { version = "1.43.0", features = ["full"] } +hyper = { version = "0.14.27", features = ["full"] } +base64 = "0.13" + +[lib] +crate-type = ["cdylib"] + +[build-dependencies] +reqwest = { version = "0.12.12", features = ["blocking", "json"] } +# tokio = { version = "1", features = ["full"] } # Required for reqwest +tar = "0.4" # For extracting tar files +flate2 = "1.0" # For handling gzip +zip = "2.2.2" diff --git a/modules/c-wrapper/Dockerfile b/modules/c-wrapper/Dockerfile new file mode 100644 index 0000000..95a60c7 --- /dev/null +++ b/modules/c-wrapper/Dockerfile @@ -0,0 +1,47 @@ +# Use an official Rust image +FROM rust:1.83-slim + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + pkg-config \ + ca-certificates \ + curl \ + gnupg \ + lsb-release \ + vim \ + && rm -rf /var/lib/apt/lists/* + +# Set the working directory +WORKDIR /app + +# Copy the project files into the container +COPY . . 
+ +# Download ONNX Runtime 1.20.0 +# RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz \ +# && tar -xvf onnxruntime-linux-x64-1.20.0.tgz \ +# && mv onnxruntime-linux-x64-1.20.0 /onnxruntime + +# Set the ONNX Runtime library path +# ENV ORT_LIB_LOCATION=/onnxruntime/lib +# ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# Set the ONNX Runtime library path +# ENV ORT_LIB_LOCATION=$(pwd)/c-wrapper/tests/test_utils/onnxruntime/lib +# ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# install python for the tests +RUN apt-get update && apt-get install -y python3 python3-pip + +# Clean and build the Rust project +# RUN cd c-wrapper/scripts && bash prep_tests.sh +# RUN cd c-wrapper && cargo build --verbose > build_log.txt 2>&1 +RUN cd c-wrapper && cargo build && bash scripts/copy_over_lib.sh + +# RUN rm /onnxruntime + +# Run the tests +CMD ["bash", "c-wrapper/scripts/run_tests.sh"] diff --git a/modules/c-wrapper/build-context/Dockerfile b/modules/c-wrapper/build-context/Dockerfile new file mode 100644 index 0000000..95a60c7 --- /dev/null +++ b/modules/c-wrapper/build-context/Dockerfile @@ -0,0 +1,47 @@ +# Use an official Rust image +FROM rust:1.83-slim + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + pkg-config \ + ca-certificates \ + curl \ + gnupg \ + lsb-release \ + vim \ + && rm -rf /var/lib/apt/lists/* + +# Set the working directory +WORKDIR /app + +# Copy the project files into the container +COPY . . 
+ +# Download ONNX Runtime 1.20.0 +# RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz \ +# && tar -xvf onnxruntime-linux-x64-1.20.0.tgz \ +# && mv onnxruntime-linux-x64-1.20.0 /onnxruntime + +# Set the ONNX Runtime library path +# ENV ORT_LIB_LOCATION=/onnxruntime/lib +# ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# Set the ONNX Runtime library path +# ENV ORT_LIB_LOCATION=$(pwd)/c-wrapper/tests/test_utils/onnxruntime/lib +# ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# install python for the tests +RUN apt-get update && apt-get install -y python3 python3-pip + +# Clean and build the Rust project +# RUN cd c-wrapper/scripts && bash prep_tests.sh +# RUN cd c-wrapper && cargo build --verbose > build_log.txt 2>&1 +RUN cd c-wrapper && cargo build && bash scripts/copy_over_lib.sh + +# RUN rm /onnxruntime + +# Run the tests +CMD ["bash", "c-wrapper/scripts/run_tests.sh"] diff --git a/modules/c-wrapper/build-context/c-wrapper/Cargo.toml b/modules/c-wrapper/build-context/c-wrapper/Cargo.toml new file mode 100644 index 0000000..cfe851c --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "c-wrapper" +version = "0.1.0" +edition = "2021" + +[dependencies] +surrealml-core = { path = "../core", features = ["dynamic"] } +uuid = { version = "1.11.1", features = ["v4"] } +ndarray = "0.16.1" + +# for the uploading the model to the server +tokio = { version = "1.43.0", features = ["full"] } +hyper = { version = "0.14.27", features = ["full"] } +base64 = "0.13" + +[lib] +crate-type = ["cdylib"] + +[build-dependencies] +reqwest = { version = "0.12.12", features = ["blocking", "json"] } +# tokio = { version = "1", features = ["full"] } # Required for reqwest +tar = "0.4" # For extracting tar files +flate2 = "1.0" # For handling gzip +zip = "2.2.2" diff --git a/modules/c-wrapper/build-context/c-wrapper/scripts/build-docker.sh 
b/modules/c-wrapper/build-context/c-wrapper/scripts/build-docker.sh new file mode 100644 index 0000000..7ffdba6 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/scripts/build-docker.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +# wipe and build the build context +BUILD_DIR="build-context" +if [ -d "$BUILD_DIR" ]; then + echo "Cleaning up existing build directory..." + rm -rf "$BUILD_DIR" +fi +mkdir "$BUILD_DIR" +mkdir "$BUILD_DIR"/c-wrapper + +# copy over the code to be built +cp -r src "$BUILD_DIR"/c-wrapper/src +cp -r tests "$BUILD_DIR"/c-wrapper/tests +cp -r scripts "$BUILD_DIR"/c-wrapper/scripts +cp Cargo.toml "$BUILD_DIR"/c-wrapper/Cargo.toml +cp build.rs "$BUILD_DIR"/c-wrapper/build.rs +cp -r ../core "$BUILD_DIR"/core +cp Dockerfile "$BUILD_DIR"/Dockerfile + +# remove unnecessary files +rm -rf "$BUILD_DIR"/core/.git +rm -rf "$BUILD_DIR"/core/target/ + +# build the docker image +cd "$BUILD_DIR" +docker build --no-cache -t c-wrapper-tests . + +docker run c-wrapper-tests +# docker run -it c-wrapper-tests /bin/bash diff --git a/modules/c-wrapper/build-context/c-wrapper/scripts/prep_tests.sh b/modules/c-wrapper/build-context/c-wrapper/scripts/prep_tests.sh new file mode 100644 index 0000000..9bcd32a --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/scripts/prep_tests.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. 
+ +# download onnxruntime +# Detect operating system +OS=$(uname -s | tr '[:upper:]' '[:lower:]') + +# Detect architecture +ARCH=$(uname -m) + +# Download the correct onnxruntime +if [ "$ARCH" == "x86_64" ] && [ "$OS" == "linux" ]; then + wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz + tar -xvf onnxruntime-linux-x64-1.20.0.tgz + mv onnxruntime-linux-x64-1.20.0 tests/test_utils/onnxruntime +else + echo "Unsupported operating system and arch: $OS $ARCH" + exit 1 +fi + +export ORT_LIB_LOCATION=$(pwd)/tests/test_utils/onnxruntime/lib +export LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +cargo build + +# Get the operating system +OS=$(uname) + +# Set the library name and extension based on the OS +case "$OS" in + "Linux") + LIB_NAME="libc_wrapper.so" + ;; + "Darwin") + LIB_NAME="libc_wrapper.dylib" + ;; + "CYGWIN"*|"MINGW"*) + LIB_NAME="libc_wrapper.dll" + ;; + *) + echo "Unsupported operating system: $OS" + exit 1 + ;; +esac + +# Source directory (where Cargo outputs the compiled library) +SOURCE_DIR="target/debug" + +# Destination directory (tests directory) +DEST_DIR="tests/test_utils" + + +# Copy the library to the tests directory +if [ -f "$SOURCE_DIR/$LIB_NAME" ]; then + cp "$SOURCE_DIR/$LIB_NAME" "$DEST_DIR/" + echo "Copied $LIB_NAME to $DEST_DIR" +else + echo "Library not found: $SOURCE_DIR/$LIB_NAME" + exit 1 +fi diff --git a/modules/c-wrapper/build-context/c-wrapper/scripts/run_tests.sh b/modules/c-wrapper/build-context/c-wrapper/scripts/run_tests.sh new file mode 100644 index 0000000..cfe54f0 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/scripts/run_tests.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +cd tests + +python3 -m unittest discover . 
diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/execution/buffered_compute.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/buffered_compute.rs new file mode 100644 index 0000000..4ac7e5e --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/buffered_compute.rs @@ -0,0 +1,164 @@ +//! This module contains the buffered_compute function that is called from the C API to compute the model. +use crate::state::STATE; +use std::ffi::{c_float, CStr, CString, c_int, c_char}; +use surrealml_core::execution::compute::ModelComputation; +use crate::utils::Vecf32Return; +use std::collections::HashMap; + + +/// Computes the model with the given data. +/// +/// # Arguments +/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model. +/// * `data_ptr` - A pointer to the data to compute. +/// * `length` - The length of the data. +/// * `strings` - A pointer to an array of strings to use as keys for the data. +/// * `string_count` - The number of strings in the array. +/// +/// # Returns +/// A Vecf32Return object containing the outcome of the computation. 
+#[no_mangle] +pub extern "C" fn buffered_compute( + file_id_ptr: *const c_char, + data_ptr: *const c_float, + data_length: usize, + strings: *const *const c_char, + string_count: c_int +) -> Vecf32Return { + if file_id_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("File id is null").unwrap().into_raw() + } + } + if data_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("Data is null").unwrap().into_raw() + } + } + + let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { + Ok(file_id) => file_id.to_owned(), + Err(error) => return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() + } + }; + + if strings.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("string pointer is null").unwrap().into_raw() + } + } + + // extract the list of strings from the C array + let string_count = string_count as usize; + let c_strings = unsafe { std::slice::from_raw_parts(strings, string_count) }; + let rust_strings: Vec = c_strings + .iter() + .map(|&s| { + if s.is_null() { + String::new() + } else { + unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() } + } + }) + .collect(); + for i in rust_strings.iter() { + if i.is_empty() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("null string passed in as key").unwrap().into_raw() + } + } + } + + let data_slice = unsafe { std::slice::from_raw_parts(data_ptr, data_length) }; + + if rust_strings.len() != data_slice.len() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: 
CString::new("String count does not match data length").unwrap().into_raw() + } + } + + // stitch the strings and data together + let mut input_map = HashMap::new(); + for (i, key) in rust_strings.iter().enumerate() { + input_map.insert(key.clone(), data_slice[i]); + } + + let mut state = match STATE.lock() { + Ok(state) => state, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() + } + } + }; + let mut file = match state.get_mut(&file_id) { + Some(file) => file, + None => { + { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() + } + } + } + }; + let compute_unit = ModelComputation { + surml_file: &mut file + }; + match compute_unit.buffered_compute(&mut input_map) { + Ok(mut output) => { + let output_len = output.len(); + let output_capacity = output.capacity(); + let output_ptr = output.as_mut_ptr(); + std::mem::forget(output); + Vecf32Return { + data: output_ptr, + length: output_len, + capacity: output_capacity, + is_error: 0, + error_message: std::ptr::null_mut() + } + }, + Err(error) => { + Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error computing model: {}", error)).unwrap().into_raw() + } + } + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/execution/mod.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/mod.rs new file mode 100644 index 0000000..590975c --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/mod.rs @@ -0,0 +1,3 @@ +//! The C API for executing ML models. 
+pub mod raw_compute; +pub mod buffered_compute; diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/execution/raw_compute.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/raw_compute.rs new file mode 100644 index 0000000..9f45038 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/execution/raw_compute.rs @@ -0,0 +1,108 @@ +//! This module contains the raw_compute function that is called from the C API to compute the model. +use crate::state::STATE; +use std::ffi::{c_float, CStr, CString, c_char}; +use surrealml_core::execution::compute::ModelComputation; +use crate::utils::Vecf32Return; + + +/// Computes the model with the given data. +/// +/// # Arguments +/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model. +/// * `data_ptr` - A pointer to the data to compute. +/// * `length` - The length of the data. +/// +/// # Returns +/// A Vecf32Return object containing the outcome of the computation. +#[no_mangle] +pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_float, length: usize) -> Vecf32Return { + + if file_id_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("File id is null").unwrap().into_raw() + } + } + if data_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("Data is null").unwrap().into_raw() + } + } + + let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { + Ok(file_id) => file_id.to_owned(), + Err(error) => return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() + } + }; + + let mut state = match STATE.lock() { + Ok(state) => state, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 
0, + is_error: 1, + error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() + } + } + }; + + let mut file = match state.get_mut(&file_id) { + Some(file) => file, + None => { + { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() + } + } + } + }; + + let slice = unsafe { std::slice::from_raw_parts(data_ptr, length) }; + let tensor = ndarray::arr1(slice).into_dyn(); + let compute_unit = ModelComputation { + surml_file: &mut file + }; + + // perform the computation + let mut outcome = match compute_unit.raw_compute(tensor, None) { + Ok(outcome) => outcome, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error computing model: {}", error.message)).unwrap().into_raw() + } + } + }; + let outcome_ptr = outcome.as_mut_ptr(); + let outcome_len = outcome.len(); + let outcome_capacity = outcome.capacity(); + std::mem::forget(outcome); + Vecf32Return { + data: outcome_ptr, + length: outcome_len, + capacity: outcome_capacity, + is_error: 0, + error_message: std::ptr::null_mut() + } +} diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/mod.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/mod.rs new file mode 100644 index 0000000..1f0b9fa --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/mod.rs @@ -0,0 +1,4 @@ +//! C API for interacting with the SurML file storage and executing models. 
+pub mod execution; +pub mod storage; +pub mod ml_sys; diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_cached_raw_model.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_cached_raw_model.rs new file mode 100644 index 0000000..5f4af21 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_cached_raw_model.rs @@ -0,0 +1,37 @@ +//! Defines the C interface for loading an ONNX model from a file and storing it in memory. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::fs::File; +use std::io::Read; +use std::os::raw::c_char; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// Local module imports +use crate::state::{generate_unique_id, STATE}; +use crate::utils::StringReturn; +use crate::{process_string_for_string_return, string_return_safe_eject}; + + + +/// Loads a ONNX model from a file wrapping it in a SurMlFile struct +/// which is stored in memory and referenced by a unique ID. +/// +/// # Arguments +/// * `file_path` - The path to the file to load. +/// +/// # Returns +/// A unique identifier for the loaded model. 
+#[no_mangle] +pub extern "C" fn load_cached_raw_model(file_path_ptr: *const c_char) -> StringReturn { + let file_path_str = process_string_for_string_return!(file_path_ptr, "file path"); + let file_id = generate_unique_id(); + let mut model = string_return_safe_eject!(File::open(file_path_str)); + let mut data = vec![]; + string_return_safe_eject!(model.read_to_end(&mut data)); + let file = SurMlFile::fresh(data); + let mut python_state = STATE.lock().unwrap(); + python_state.insert(file_id.clone(), file); + StringReturn::success(file_id) +} diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_model.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_model.rs new file mode 100644 index 0000000..bb33e4f --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/load_model.rs @@ -0,0 +1,135 @@ +//! Defines the C interface for loading a surml file and getting the meta data around the model. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::{c_char, c_int}; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// Local module imports +use crate::state::{generate_unique_id, STATE}; + + +/// Holds the data around the outcome of the load_model function. +/// +/// # Fields +/// * `file_id` - The unique identifier for the loaded model. +/// * `name` - The name of the model. +/// * `description` - The description of the model. +/// * `version` - The version of the model. +/// * `error_message` - An error message if the loading failed. +/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success). +#[repr(C)] +pub struct FileInfo { + pub file_id: *mut c_char, + pub name: *mut c_char, + pub description: *mut c_char, + pub version: *mut c_char, + pub error_message: *mut c_char, + pub is_error: c_int, +} + + +/// Frees the memory allocated for the file info. +/// +/// # Arguments +/// * `info` - The file info to free. 
+#[no_mangle] +pub extern "C" fn free_file_info(info: FileInfo) { + // Free all allocated strings if they are not null + if !info.file_id.is_null() { + unsafe { drop(CString::from_raw(info.file_id)) }; + } + if !info.name.is_null() { + unsafe { drop(CString::from_raw(info.name)) }; + } + if !info.description.is_null() { + unsafe { drop(CString::from_raw(info.description)) }; + } + if !info.version.is_null() { + unsafe { drop(CString::from_raw(info.version)) }; + } + if !info.error_message.is_null() { + unsafe { drop(CString::from_raw(info.error_message)) }; + } +} + +/// Loads a model from a file and returns a unique identifier for the loaded model. +/// +/// # Arguments +/// * `file_path_ptr` - A pointer to the file path of the model to load. +/// +/// # Returns +/// Meta data around the model and a unique identifier for the loaded model. +#[no_mangle] +pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo { + + // checking that the file path pointer is not null + if file_path_ptr.is_null() { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new("Received a null pointer for file path").unwrap().into_raw(), + is_error: 1 + }; + } + + // Convert the raw C string to a Rust string + let c_str = unsafe { CStr::from_ptr(file_path_ptr) }; + + // convert the CStr into a &str + let file_path = match c_str.to_str() { + Ok(rust_str) => rust_str, + Err(_) => { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw(), + is_error: 1 + }; + } + }; + + let file = match SurMlFile::from_file(&file_path) { + Ok(file) => file, + Err(e) => { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: 
std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new(e.to_string()).unwrap().into_raw(), + is_error: 1 + }; + } + }; + + // get the meta data from the file + let name = file.header.name.to_string(); + let description = file.header.description.to_string(); + let version = file.header.version.to_string(); + + // insert the file into the state + let file_id = generate_unique_id(); + let mut state = STATE.lock().unwrap(); + state.insert(file_id.clone(), file); + + // return the meta data + let file_id = CString::new(file_id).unwrap(); + let name = CString::new(name).unwrap(); + let description = CString::new(description).unwrap(); + let version = CString::new(version).unwrap(); + + FileInfo { + file_id: file_id.into_raw(), + name: name.into_raw(), + description: description.into_raw(), + version: version.into_raw(), + error_message: std::ptr::null_mut(), + is_error: 0 + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/meta.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/meta.rs new file mode 100644 index 0000000..d4fb797 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/meta.rs @@ -0,0 +1,211 @@ +//! Defines the C API interface for interacting with the meta data of a SurML file. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +// External crate imports +use surrealml_core::storage::header::normalisers::wrapper::NormaliserType; + +// Local module imports +use crate::state::STATE; +use crate::utils::EmptyReturn; +use crate::{empty_return_safe_eject, process_string_for_empty_return}; + + + +/// Adds a name to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `model_name` - The name of the model to be added. 
+#[no_mangle] +pub extern "C" fn add_name(file_id_ptr: *const c_char, model_name_ptr: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id_ptr, "file id"); + let model_name = process_string_for_empty_return!(model_name_ptr, "model name"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + wrapped_file.header.add_name(model_name); + EmptyReturn::success() +} + + +/// Adds a description to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `description` - The description of the model to be added. +#[no_mangle] +pub extern "C" fn add_description(file_id_ptr: *const c_char, description_ptr: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id_ptr, "file id"); + let description = process_string_for_empty_return!(description_ptr, "description"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + wrapped_file.header.add_description(description); + EmptyReturn::success() +} + + +/// Adds a version to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `version` - The version of the model to be added. +#[no_mangle] +pub extern "C" fn add_version(file_id: *const c_char, version: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id, "file id"); + let version = process_string_for_empty_return!(version, "version"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + let _ = wrapped_file.header.add_version(version); + EmptyReturn::success() +} + + +/// Adds a column to the SurMlFile struct. 
+/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `column_name` - The name of the column to be added. +#[no_mangle] +pub extern "C" fn add_column(file_id: *const c_char, column_name: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id, "file id"); + let column_name = process_string_for_empty_return!(column_name, "column name"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + wrapped_file.header.add_column(column_name); + EmptyReturn::success() +} + + +/// adds an author to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `author` - The author to be added. +#[no_mangle] +pub extern "C" fn add_author(file_id: *const c_char, author: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id, "file id"); + let author = process_string_for_empty_return!(author, "author"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + wrapped_file.header.add_author(author); + EmptyReturn::success() +} + + +/// Adds an origin of where the model was trained to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `origin` - The origin to be added. +#[no_mangle] +pub extern "C" fn add_origin(file_id: *const c_char, origin: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id, "file id"); + let origin = process_string_for_empty_return!(origin, "origin"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + let _ = wrapped_file.header.add_origin(origin); + EmptyReturn::success() +} + + +/// Adds an engine to the SurMlFile struct. 
+/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `engine` - The engine to be added. +#[no_mangle] +pub extern "C" fn add_engine(file_id: *const c_char, engine: *const c_char) -> EmptyReturn { + let file_id = process_string_for_empty_return!(file_id, "file id"); + let engine = process_string_for_empty_return!(engine, "engine"); + let mut state = STATE.lock().unwrap(); + let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + wrapped_file.header.add_engine(engine); + EmptyReturn::success() +} + + +/// Adds an output to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `output_name` - The name of the output to be added. +/// * `normaliser_label` (Optional) - The label of the normaliser to be applied to the output. +/// * `one` (Optional) - The first parameter of the normaliser. +/// * `two` (Optional) - The second parameter of the normaliser. 
+#[no_mangle]
+pub extern "C" fn add_output(
+    file_id_ptr: *const c_char,
+    output_name_ptr: *const c_char,
+    normaliser_label_ptr: *const c_char,
+    one: *const c_char,
+    two: *const c_char
+) -> EmptyReturn {
+
+    let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
+    let output_name = process_string_for_empty_return!(output_name_ptr, "output name");
+
+    let normaliser_label = if normaliser_label_ptr.is_null() {
+        None
+    }
+    else {
+        Some(process_string_for_empty_return!(normaliser_label_ptr, "normaliser label"))
+    };
+
+    let one = if one.is_null() {
+        None
+    }
+    else {
+        Some(
+            empty_return_safe_eject!(process_string_for_empty_return!(one, "one").parse::<f32>()) // fix: turbofish type was stripped in transit; bare `parse::()` is a syntax error
+        )
+    };
+    let two = if two.is_null() {
+        None
+    }
+    else {
+        Some(
+            empty_return_safe_eject!(process_string_for_empty_return!(two, "two").parse::<f32>()) // fix: turbofish type was stripped in transit
+        )
+    };
+
+    let mut state = STATE.lock().unwrap();
+    let file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    if let Some(normaliser_label) = normaliser_label {
+        let normaliser = NormaliserType::new(normaliser_label, one.unwrap(), two.unwrap()); // NOTE(review): panics across the FFI boundary if a label is given without both params — consider an error return
+        file.header.add_output(output_name, Some(normaliser));
+    }
+    else {
+        file.header.add_output(output_name, None);
+    }
+    EmptyReturn::success()
+}
+
+
+/// Adds a normaliser to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `column_name` - The name of the column to which the normaliser will be applied.
+/// * `normaliser_label` - The label of the normaliser to be applied to the column.
+/// * `one` - The first parameter of the normaliser.
+/// * `two` - The second parameter of the normaliser.
+#[no_mangle] +pub extern "C" fn add_normaliser( + file_id_ptr: *const c_char, + column_name_ptr: *const c_char, + normaliser_label_ptr: *const c_char, + one: f32, + two: f32 +) -> EmptyReturn { + + let file_id = process_string_for_empty_return!(file_id_ptr, "file id"); + let column_name = process_string_for_empty_return!(column_name_ptr, "column name"); + let normaliser_label = process_string_for_empty_return!(normaliser_label_ptr, "normaliser label"); + + let normaliser = NormaliserType::new(normaliser_label, one, two); + let mut state = STATE.lock().unwrap(); + let file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option); + let _ = file.header.normalisers.add_normaliser(normaliser, column_name, &file.header.keys); + EmptyReturn::success() +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/mod.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/mod.rs new file mode 100644 index 0000000..f049a57 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/mod.rs @@ -0,0 +1,7 @@ +//! C Storage API +pub mod load_model; +pub mod save_model; +pub mod load_cached_raw_model; +pub mod to_bytes; +pub mod meta; +pub mod upload_model; diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/save_model.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/save_model.rs new file mode 100644 index 0000000..7839999 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/save_model.rs @@ -0,0 +1,32 @@ +//! Save a model to a file, deleting the file from the `STATE` in the process. 
+// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// Local module imports +use crate::state::STATE; +use crate::utils::EmptyReturn; +use crate::{empty_return_safe_eject, process_string_for_empty_return}; + + +/// Saves a model to a file, deleting the file from the `PYTHON_STATE` in the process. +/// +/// # Arguments +/// * `file_path` - The path to the file to save to. +/// * `file_id` - The unique identifier for the loaded model. +/// +/// # Returns +/// An empty return object indicating success or failure. +#[no_mangle] +pub extern "C" fn save_model(file_path_ptr: *const c_char, file_id_ptr: *const c_char) -> EmptyReturn { + let file_path_str = process_string_for_empty_return!(file_path_ptr, "file path"); + let file_id_str = process_string_for_empty_return!(file_id_ptr, "file id"); + let mut state = STATE.lock().unwrap(); + let file: &mut SurMlFile = empty_return_safe_eject!(state.get_mut(&file_id_str), "Model not found", Option); + empty_return_safe_eject!(file.write(&file_path_str)); + state.remove(&file_id_str); + EmptyReturn::success() +} diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/to_bytes.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/to_bytes.rs new file mode 100644 index 0000000..44de0d2 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/to_bytes.rs @@ -0,0 +1,27 @@ +//! convert the entire SurML file to bytes +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +// Local module imports +use crate::state::STATE; +use crate::utils::VecU8Return; +use crate::process_string_for_vec_u8_return; + + + +/// Converts the entire SurML file to bytes. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// +/// # Returns +/// A vector of bytes representing the entire file. 
+#[no_mangle] +pub extern "C" fn to_bytes(file_id_ptr: *const c_char) -> VecU8Return { + let file_id = process_string_for_vec_u8_return!(file_id_ptr, "file id"); + let mut state = STATE.lock().unwrap(); + let file = state.get_mut(&file_id).unwrap(); + let raw_bytes = file.to_bytes(); + VecU8Return::success(raw_bytes) +} diff --git a/modules/c-wrapper/build-context/c-wrapper/src/api/storage/upload_model.rs b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/upload_model.rs new file mode 100644 index 0000000..098aae9 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/api/storage/upload_model.rs @@ -0,0 +1,86 @@ +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +// External crate imports +use base64::encode; +use hyper::{ + Body, Client, Method, Request, Uri, + header::{AUTHORIZATION, CONTENT_TYPE, HeaderValue}, +}; +use surrealml_core::storage::stream_adapter::StreamAdapter; + +// Local module imports +use crate::utils::EmptyReturn; +use crate::{empty_return_safe_eject, process_string_for_empty_return}; + + +/// Uploads a model to a remote server. +/// +/// # Arguments +/// * `file_path_ptr` - The path to the file to upload. +/// * `url_ptr` - The URL to upload the file to. +/// * `chunk_size` - The size of the chunks to upload the file in. +/// * `ns_ptr` - The namespace to upload the file to. +/// * `db_ptr` - The database to upload the file to. +/// * `username_ptr` - The username to use for authentication. +/// * `password_ptr` - The password to use for authentication. +/// +/// # Returns +/// An empty return object indicating success or failure. 
+#[no_mangle] +pub extern "C" fn upload_model( + file_path_ptr: *const c_char, + url_ptr: *const c_char, + chunk_size: usize, + ns_ptr: *const c_char, + db_ptr: *const c_char, + username_ptr: *const c_char, + password_ptr: *const c_char +) -> EmptyReturn { + // process the inputs + let file_path = process_string_for_empty_return!(file_path_ptr, "file path"); + let url = process_string_for_empty_return!(url_ptr, "url"); + let ns = process_string_for_empty_return!(ns_ptr, "namespace"); + let db = process_string_for_empty_return!(db_ptr, "database"); + let username = match username_ptr.is_null() { + true => None, + false => Some(process_string_for_empty_return!(username_ptr, "username")) + }; + let password = match password_ptr.is_null() { + true => None, + false => Some(process_string_for_empty_return!(password_ptr, "password")) + }; + + let client = Client::new(); + + let uri: Uri = empty_return_safe_eject!(url.parse()); + let generator = empty_return_safe_eject!(StreamAdapter::new(chunk_size, file_path)); + let body = Body::wrap_stream(generator); + + let part_req = Request::builder() + .method(Method::POST) + .uri(uri) + .header(CONTENT_TYPE, "application/octet-stream") + .header("surreal-ns", empty_return_safe_eject!(HeaderValue::from_str(&ns))) + .header("surreal-db", empty_return_safe_eject!(HeaderValue::from_str(&db))); + + let req; + if username.is_none() == false && password.is_none() == false { + // unwraps are safe because we have already checked that the values are not None + let encoded_credentials = encode(format!("{}:{}", username.unwrap(), password.unwrap())); + req = empty_return_safe_eject!(part_req.header(AUTHORIZATION, format!("Basic {}", encoded_credentials)) + .body(body)); + } + else { + req = empty_return_safe_eject!(part_req.body(body)); + } + + let tokio_runtime = empty_return_safe_eject!(tokio::runtime::Builder::new_current_thread().enable_io() + .enable_time() + .build()); + tokio_runtime.block_on( async move { + let _response = 
client.request(req).await.unwrap(); + }); + EmptyReturn::success() +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/src/lib.rs b/modules/c-wrapper/build-context/c-wrapper/src/lib.rs new file mode 100644 index 0000000..b44dfd2 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/lib.rs @@ -0,0 +1,4 @@ +//! C lib for interacting with the SurML file storage and executing models. +mod state; +mod api; +mod utils; diff --git a/modules/c-wrapper/build-context/c-wrapper/src/state.rs b/modules/c-wrapper/build-context/c-wrapper/src/state.rs new file mode 100644 index 0000000..e8b56c4 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/state.rs @@ -0,0 +1,30 @@ +//! Defines operations for handling memory of a python program that is accessing the rust library. +// Standard library imports +use std::collections::HashMap; +use std::sync::{Arc, LazyLock, Mutex}; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// External library imports +use uuid::Uuid; + + +/// A hashmap of unique identifiers to loaded machine learning models. As long as the python program keeps the unique +/// identifier it can access the loaded machine learning model. It is best to keep as little as possible on the python +/// side and keep as much as possible on the rust side. Therefore bindings to other languages can be created with ease +/// and a command line tool can also be created without much need for new features. This will also ensure consistency +/// between other languages and the command line tool. +pub static STATE: LazyLock<Arc<Mutex<HashMap<String, SurMlFile>>>> = LazyLock::new(|| { + Arc::new(Mutex::new(HashMap::new())) +}); + + +/// Generates a unique identifier that can be used to access a loaded machine learning model. +/// +/// # Returns +/// A unique identifier that can be used to access a loaded machine learning model. 
+pub fn generate_unique_id() -> String { + let uuid = Uuid::new_v4(); + uuid.to_string() +} diff --git a/modules/c-wrapper/build-context/c-wrapper/src/utils.rs b/modules/c-wrapper/build-context/c-wrapper/src/utils.rs new file mode 100644 index 0000000..693093c --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/src/utils.rs @@ -0,0 +1,363 @@ +//! Defines macros and C structs for reducing the amount of boilerplate code required for the C API. +use std::os::raw::{c_char, c_int}; +use std::ffi::CString; + + +/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as an `EmptyReturn`. +/// +/// # Arguments +/// * `str_ptr` - The pointer to the string. +/// * `var_name` - The name of the variable being processed (for error messages). +#[macro_export] +macro_rules! process_string_for_empty_return { + ($str_ptr:expr, $var_name:expr) => { + match $str_ptr.is_null() { + true => { + return EmptyReturn { + is_error: 1, + error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw() + }; + }, + false => { + let c_str = unsafe { CStr::from_ptr($str_ptr) }; + match c_str.to_str() { + Ok(s) => s.to_owned(), + Err(_) => { + return EmptyReturn { + is_error: 1, + error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw() + }; + } + } + } + } + }; + ($str_ptr:expr, $var_name:expr, Option) => { + match $str_ptr.is_null() { + true => { + return None; + }, + false => { + let c_str = unsafe { CStr::from_ptr($str_ptr) }; + match c_str.to_str() { + Ok(s) => Some(s.to_owned()), + Err(_) => { + return EmptyReturn { + is_error: 1, + error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw() + }; + } + } + } + } + } +} + +/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as a `StringReturn`. 
+/// +/// # Arguments +/// * `str_ptr` - The pointer to the string. +/// * `var_name` - The name of the variable being processed (for error messages). +#[macro_export] +macro_rules! process_string_for_string_return { + ($str_ptr:expr, $var_name:expr) => { + match $str_ptr.is_null() { + true => { + return StringReturn { + is_error: 1, + error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw(), + string: std::ptr::null_mut() + }; + }, + false => { + let c_str = unsafe { CStr::from_ptr($str_ptr) }; + match c_str.to_str() { + Ok(s) => s.to_owned(), + Err(_) => { + return StringReturn { + is_error: 1, + error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw(), + string: std::ptr::null_mut() + }; + } + } + } + } + }; +} + + +/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as a `VecU8Return`. +/// +/// # Arguments +/// * `str_ptr` - The pointer to the string. +/// * `var_name` - The name of the variable being processed (for error messages). +#[macro_export] +macro_rules! process_string_for_vec_u8_return { + ($str_ptr:expr, $var_name:expr) => { + match $str_ptr.is_null() { + true => { + return VecU8Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw() + }; + }, + false => { + let c_str = unsafe { CStr::from_ptr($str_ptr) }; + match c_str.to_str() { + Ok(s) => s.to_owned(), + Err(_) => { + return VecU8Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw() + }; + } + } + } + } + }; +} + + +/// Checks the result of an execution and returns an `StringReturn` if an error occurred. 
+/// +/// # Arguments +/// * `execution` - The execution such as a function call to map to `StringReturn` if an error occurred. +#[macro_export] +macro_rules! string_return_safe_eject { + ($execution:expr) => { + match $execution { + Ok(s) => s, + Err(e) => { + return StringReturn { + string: std::ptr::null_mut(), + is_error: 1, + error_message: CString::new(e.to_string()).unwrap().into_raw() + } + } + } + }; +} + + +/// Checks the result of an execution and returns an `EmptyReturn` if an error occurred or a none is returned. +/// +/// # Arguments +/// * `execution` - The execution such as a function call to map to `EmptyReturn` if an error occurred. +/// * `var` - The variable name to include in the error message. +/// * `Option` - The type of the execution. +/// +/// # Arguments +/// * `execution` - The execution such as a function call to map to `EmptyReturn` if an error occurred. +#[macro_export] +macro_rules! empty_return_safe_eject { + ($execution:expr, $var:expr, Option) => { + match $execution { + Some(s) => s, + None => { + return EmptyReturn { + is_error: 1, + error_message: CString::new($var).unwrap().into_raw() + } + } + } + }; + ($execution:expr) => { + match $execution { + Ok(s) => s, + Err(e) => { + return EmptyReturn { + is_error: 1, + error_message: CString::new(e.to_string()).unwrap().into_raw() + } + } + } + }; +} + + +/// Returns a simple String to the caller. +/// +/// # Fields +/// * `string` - The string to return. +/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not). +/// * `error_message` - An optional error message. +#[repr(C)] +pub struct StringReturn { + pub string: *mut c_char, + pub is_error: c_int, + pub error_message: *mut c_char +} + + +impl StringReturn { + + /// Returns a new `StringReturn` object with the string and no error. + /// + /// # Arguments + /// * `string` - The string to return. + /// + /// # Returns + /// A new `StringReturn` object. 
+ pub fn success(string: String) -> Self { + StringReturn { + string: CString::new(string).unwrap().into_raw(), + is_error: 0, + error_message: std::ptr::null_mut() + } + } +} + + +/// Frees the memory allocated for the `StringReturn` object. +/// +/// # Arguments +/// * `string_return` - The `StringReturn` object to free. +#[no_mangle] +pub extern "C" fn free_string_return(string_return: StringReturn) { + // Free the string if it is not null + if !string_return.string.is_null() { + unsafe { drop(CString::from_raw(string_return.string)) }; + } + // Free the error message if it is not null + if !string_return.error_message.is_null() { + unsafe { drop(CString::from_raw(string_return.error_message)) }; + } +} + + +/// Returns a simple empty return object to the caller. +/// +/// # Fields +/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not). +/// * `error_message` - An optional error message. +#[repr(C)] +pub struct EmptyReturn { + pub is_error: c_int, // 0 for success, 1 for error + pub error_message: *mut c_char, // Optional error message +} + +impl EmptyReturn { + + /// Returns a new `EmptyReturn` object with no error. + /// + /// # Returns + /// A new `EmptyReturn` object. + pub fn success() -> Self { + EmptyReturn { + is_error: 0, + error_message: std::ptr::null_mut() + } + } +} + + +/// Frees the memory allocated for the `EmptyReturn` object. +/// +/// # Arguments +/// * `empty_return` - The `EmptyReturn` object to free. +#[no_mangle] +pub extern "C" fn free_empty_return(empty_return: EmptyReturn) { + // Free the error message if it is not null + if !empty_return.error_message.is_null() { + unsafe { drop(CString::from_raw(empty_return.error_message)) }; + } +} + + +/// Returns a vector of bytes to the caller. +/// +/// # Fields +/// * `data` - The pointer to the data. +/// * `length` - The length of the data. +/// * `capacity` - The capacity of the data. +/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not). 
+/// * `error_message` - An optional error message. +#[repr(C)] +pub struct VecU8Return { + pub data: *mut u8, + pub length: usize, + pub capacity: usize, // Optional if you want to include capacity for clarity + pub is_error: c_int, + pub error_message: *mut c_char +} + + +impl VecU8Return { + + /// Returns a new `VecU8Return` object with the data and no error. + /// + /// # Arguments + /// * `data` - The data to return. + /// + /// # Returns + /// A new `VecU8Return` object. + pub fn success(data: Vec<u8>) -> Self { + let mut data = data; + let data_ptr = data.as_mut_ptr(); + let length = data.len(); + let capacity = data.capacity(); + std::mem::forget(data); + VecU8Return { + data: data_ptr, + length, + capacity, + is_error: 0, + error_message: std::ptr::null_mut() + } + } +} + + +/// Frees the memory allocated for the `VecU8Return` object. +/// +/// # Arguments +/// * `vec_u8` - The `VecU8Return` object to free. +#[no_mangle] +pub extern "C" fn free_vec_u8(vec_u8: VecU8Return) { + // Free the data if it is not null + if !vec_u8.data.is_null() { + unsafe { drop(Vec::from_raw_parts(vec_u8.data, vec_u8.length, vec_u8.capacity)) }; + } +} + + +/// Holds the data around the outcome of the raw_compute function. +/// +/// # Fields +/// * `data` - The data returned from the computation. +/// * `length` - The length of the data. +/// * `capacity` - The capacity of the data. +/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success). +/// * `error_message` - An error message if the computation failed. +#[repr(C)] +pub struct Vecf32Return { + pub data: *mut f32, + pub length: usize, + pub capacity: usize, // Optional if you want to include capacity for clarity + pub is_error: c_int, + pub error_message: *mut c_char +} + + +/// Frees the memory allocated for the Vecf32Return. +/// +/// # Arguments +/// * `vecf32_return` - The Vecf32Return to free. 
+#[no_mangle] +pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) { + // Free the data if it is not null + if !vecf32_return.data.is_null() { + unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) }; + } + // Free the error message if it is not null + if !vecf32_return.error_message.is_null() { + unsafe { drop(CString::from_raw(vecf32_return.error_message)) }; + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/__init__.py b/modules/c-wrapper/build-context/c-wrapper/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/__init__.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/__init__.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_buffered_compute.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_buffered_compute.py new file mode 100644 index 0000000..c940747 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_buffered_compute.py @@ -0,0 +1,80 @@ +import ctypes +from unittest import TestCase, main + +from test_utils.c_lib_loader import load_library +from test_utils.return_structs import FileInfo, Vecf32Return +from test_utils.routes import TEST_SURML_PATH + + +class TestExecution(TestCase): + + def setUp(self) -> None: + self.lib = load_library() + + # Define the Rust function signatures + self.lib.load_model.argtypes = [ctypes.c_char_p] + self.lib.load_model.restype = FileInfo + + self.lib.free_file_info.argtypes = [FileInfo] + + self.lib.buffered_compute.argtypes = [ + ctypes.c_char_p, # file_id_ptr -> *const c_char + 
ctypes.POINTER(ctypes.c_float), # data_ptr -> *const c_float + ctypes.c_size_t, # data_length -> usize + ctypes.POINTER(ctypes.c_char_p), # strings -> *const *const c_char + ctypes.c_int # string_count -> c_int + ] + self.lib.buffered_compute.restype = Vecf32Return + + self.lib.free_vecf32_return.argtypes = [Vecf32Return] + + def test_buffered_compute(self): + # Load a test model + c_string = str(TEST_SURML_PATH).encode('utf-8') + file_info = self.lib.load_model(c_string) + + if file_info.error_message: + self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}") + + input_data = { + "squarefoot": 500.0, + "num_floors": 2.0 + } + + string_buffer = [] + data_buffer = [] + for key, value in input_data.items(): + string_buffer.append(key.encode('utf-8')) + data_buffer.append(value) + + # Prepare input data as a ctypes array + array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size + input_data = array_type(*data_buffer) # Instantiate the array with the list elements + + # prepare the input strings + string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer) + string_count = len(string_buffer) + + # Call the raw_compute function + result = self.lib.buffered_compute( + file_info.file_id, + input_data, + len(input_data), + string_array, + string_count + ) + + if result.is_error: + self.fail(f"Error in buffered_compute: {result.error_message.decode('utf-8')}") + + # Extract and verify the computation result + outcome = [result.data[i] for i in range(result.length)] + self.assertEqual(362.9851989746094, outcome[0]) + + # Free allocated memory + self.lib.free_vecf32_return(result) + self.lib.free_file_info(file_info) + + +if __name__ == '__main__': + main() diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_raw_compute.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_raw_compute.py new file mode 100644 index 0000000..30dba87 --- /dev/null +++ 
b/modules/c-wrapper/build-context/c-wrapper/tests/api/execution/test_raw_compute.py @@ -0,0 +1,54 @@ +import ctypes +from unittest import TestCase, main + +from test_utils.c_lib_loader import load_library +from test_utils.return_structs import FileInfo, Vecf32Return +from test_utils.routes import TEST_SURML_PATH + + +class TestExecution(TestCase): + + def setUp(self) -> None: + self.lib = load_library() + + # Define the Rust function signatures + self.lib.load_model.argtypes = [ctypes.c_char_p] + self.lib.load_model.restype = FileInfo + + self.lib.free_file_info.argtypes = [FileInfo] + + self.lib.raw_compute.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t] + self.lib.raw_compute.restype = Vecf32Return + + self.lib.free_vecf32_return.argtypes = [Vecf32Return] + + def test_raw_compute(self): + # Load a test model + c_string = str(TEST_SURML_PATH).encode('utf-8') + file_info = self.lib.load_model(c_string) + + if file_info.error_message: + self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}") + + # Prepare input data as a ctypes array + data_buffer = [1.0, 4.0] + array_type = ctypes.c_float * len(data_buffer) # Create an array type of the appropriate size + input_data = array_type(*data_buffer) # Instantiate the array with the list elements + + # Call the raw_compute function + result = self.lib.raw_compute(file_info.file_id, input_data, len(input_data)) + + if result.is_error: + self.fail(f"Error in raw_compute: {result.error_message.decode('utf-8')}") + + # Extract and verify the computation result + outcome = [result.data[i] for i in range(result.length)] + self.assertEqual(1.8246129751205444, outcome[0]) + + # Free allocated memory + self.lib.free_vecf32_return(result) + self.lib.free_file_info(file_info) + + +if __name__ == '__main__': + main() diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/__init__.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/__init__.py new file 
mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_cached_raw_model.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_cached_raw_model.py new file mode 100644 index 0000000..8ce31d7 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_cached_raw_model.py @@ -0,0 +1,46 @@ +import ctypes +from unittest import TestCase, main + +from test_utils.c_lib_loader import load_library +from test_utils.return_structs import StringReturn +from test_utils.routes import SHOULD_BREAK_FILE, TEST_ONNX_FILE_PATH + + +class TestLoadCachedRawModel(TestCase): + + def setUp(self) -> None: + self.lib = load_library() + # define the types + self.lib.load_cached_raw_model.restype = StringReturn + self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p] + + def test_null_pointer_protection(self): + null_pointer = None + outcome: StringReturn = self.lib.load_cached_raw_model(null_pointer) + self.assertEqual(1, outcome.is_error) + self.assertEqual("Received a null pointer for file path", outcome.error_message.decode('utf-8')) + + def test_wrong_path(self): + wrong_path = "should_break".encode('utf-8') + outcome: StringReturn = self.lib.load_cached_raw_model(wrong_path) + self.assertEqual(1, outcome.is_error) + self.assertEqual( + "No such file or directory (os error 2)", + outcome.error_message.decode('utf-8') + ) + + def test_wrong_file_format(self): + wrong_file_type = str(SHOULD_BREAK_FILE).encode('utf-8') + outcome: StringReturn = self.lib.load_cached_raw_model(wrong_file_type) + # below is unexpected and also happens in the old API + # TODO => throw an error if the file format is incorrect + self.assertEqual(0, outcome.is_error) + + def test_success(self): + right_file = str(TEST_ONNX_FILE_PATH).encode('utf-8') + outcome: StringReturn = self.lib.load_cached_raw_model(right_file) + self.assertEqual(0, outcome.is_error) + + +if __name__ == '__main__': + main() 
diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_model.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_model.py new file mode 100644 index 0000000..0c31503 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_load_model.py @@ -0,0 +1,39 @@ +import ctypes +from unittest import TestCase, main + +from test_utils.c_lib_loader import load_library +from test_utils.return_structs import FileInfo +from test_utils.routes import SHOULD_BREAK_FILE, TEST_SURML_PATH + + +class TestLoadModel(TestCase): + + def setUp(self) -> None: + self.lib = load_library() + self.lib.load_model.restype = FileInfo + self.lib.load_model.argtypes = [ctypes.c_char_p] + self.lib.free_file_info.argtypes = [FileInfo] + + def test_null_pointer_protection(self): + null_pointer = None + outcome: FileInfo = self.lib.load_model(null_pointer) + self.assertEqual(1, outcome.is_error) + self.assertEqual("Received a null pointer for file path", outcome.error_message.decode('utf-8')) + + def test_wrong_file(self): + wrong_file_type = str(SHOULD_BREAK_FILE).encode('utf-8') + outcome: FileInfo = self.lib.load_model(wrong_file_type) + self.assertEqual(1, outcome.is_error) + self.assertEqual(True, "failed to fill whole buffer" in outcome.error_message.decode('utf-8')) + + def test_success(self): + surml_file_path = str(TEST_SURML_PATH).encode('utf-8') + outcome: FileInfo = self.lib.load_model(surml_file_path) + self.assertEqual(0, outcome.is_error) + self.lib.free_file_info(outcome) + + + + +if __name__ == '__main__': + main() diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_meta.py b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_meta.py new file mode 100644 index 0000000..c4abf91 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/api/storage/test_meta.py @@ -0,0 +1,142 @@ +""" +Tests all the meta data functions +""" +import ctypes +from unittest import 
TestCase, main +from typing import Optional +import os + +from test_utils.c_lib_loader import load_library +from test_utils.return_structs import EmptyReturn, FileInfo, StringReturn +from test_utils.routes import TEST_SURML_PATH, TEST_ONNX_FILE_PATH, ASSETS_PATH + + +class TestMeta(TestCase): + + def setUp(self) -> None: + self.lib = load_library() + self.lib.add_name.restype = EmptyReturn + + # Define the signatues of the basic meta functions + self.functions = [ + self.lib.add_name, + self.lib.add_description, + self.lib.add_version, + self.lib.add_column, + self.lib.add_author, + self.lib.add_origin, + self.lib.add_engine, + ] + for i in self.functions: + i.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + i.restype = EmptyReturn + + # Define the load model signature + self.lib.load_model.restype = FileInfo + self.lib.load_model.argtypes = [ctypes.c_char_p] + self.lib.free_file_info.argtypes = [FileInfo] + # define the load raw model signature + self.lib.load_cached_raw_model.restype = StringReturn + self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p] + # define the save model signature + self.lib.save_model.restype = EmptyReturn + self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p] + # load the model for tests + self.model: FileInfo = self.lib.load_model(str(TEST_SURML_PATH).encode('utf-8')) + self.file_id = self.model.file_id.decode('utf-8') + self.temp_test_id: Optional[str] = None + + def tearDown(self) -> None: + self.lib.free_file_info(self.model) + + # remove the temp surml file created in assets if present + if self.test_temp_surml_file_path is not None: + os.remove(self.test_temp_surml_file_path) + + def test_null_protection(self): + placeholder = "placeholder".encode('utf-8') + file_id = self.file_id.encode('utf-8') + + # check that they all protect against file ID null pointers + for i in self.functions: + outcome: EmptyReturn = i(None, placeholder) + self.assertEqual(1, outcome.is_error) + self.assertEqual( + "Received a null 
pointer for file id", + outcome.error_message.decode('utf-8') + ) + + # check that they all protect against null pointers for the field type + outcomes = [ + "model name", + "description", + "version", + "column name", + "author", + "origin", + "engine", + ] + counter = 0 + for i in self.functions: + outcome: EmptyReturn = i(file_id, None) + self.assertEqual(1, outcome.is_error) + self.assertEqual( + f"Received a null pointer for {outcomes[counter]}", + outcome.error_message.decode('utf-8') + ) + counter += 1 + + def test_model_not_found(self): + placeholder = "placeholder".encode('utf-8') + + # check they all return errors if not found + for i in self.functions: + outcome: EmptyReturn = i(placeholder, placeholder) + self.assertEqual(1, outcome.is_error) + self.assertEqual("Model not found", outcome.error_message.decode('utf-8')) + + def test_add_metadata_and_save(self): + file_id: StringReturn = self.lib.load_cached_raw_model(str(TEST_SURML_PATH).encode('utf-8')) + self.assertEqual(0, file_id.is_error) + + decoded_file_id = file_id.string.decode('utf-8') + self.temp_test_id = decoded_file_id + + self.assertEqual( + 0, + self.lib.add_name(file_id.string, "test name".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_description(file_id.string, "test description".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_version(file_id.string, "0.0.1".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.add_author(file_id.string, "test author".encode('utf-8')).is_error + ) + self.assertEqual( + 0, + self.lib.save_model(self.test_temp_surml_file_path.encode("utf-8"), file_id.string).is_error + ) + + outcome: FileInfo = self.lib.load_model(self.test_temp_surml_file_path.encode('utf-8')) + self.assertEqual(0, outcome.is_error) + self.assertEqual("test name", outcome.name.decode('utf-8')) + self.assertEqual("test description", outcome.description.decode('utf-8')) + self.assertEqual("0.0.1", outcome.version.decode('utf-8')) + + 
+ @property + def test_temp_surml_file_path(self) -> Optional[str]: + if self.temp_test_id is None: + return None + return str(ASSETS_PATH.joinpath(f"{self.temp_test_id}.surml")) + + + +if __name__ == '__main__': + main() diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/__init__.py b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/linear_test.onnx b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/linear_test.onnx new file mode 100644 index 0000000..f7b070b --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/linear_test.onnx @@ -0,0 +1,15 @@ +pytorch2.0.1:… +Q +onnx::MatMul_0 +onnx::MatMul_6/linear/MatMul_output_0/linear/MatMul"MatMul +; + linear.bias +/linear/MatMul_output_05 /linear/Add"Add torch_jit*B linear.biasJdÕ ²* Bonnx::MatMul_6J ‘9?[ÄŒ>Z +onnx::MatMul_0 + + +b +5 + + +B \ No newline at end of file diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/should_break.txt b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/should_break.txt new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/test.surml b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/test.surml new file mode 100644 index 0000000..61da29a Binary files /dev/null and b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/assets/test.surml differ diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/c_lib_loader.py b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/c_lib_loader.py new file mode 100644 index 0000000..00adfc6 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/c_lib_loader.py @@ -0,0 +1,56 @@ +import ctypes +import platform +from pathlib import Path +import os +from 
test_utils.return_structs import EmptyReturn + + +def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL: + """ + Load the correct shared library based on the operating system. + + Args: + lib_name (str): The base name of the library without extension (e.g., "libc_wrapper"). + + Returns: + ctypes.CDLL: The loaded shared library. + """ + current_dir = Path(__file__).parent + system_name = platform.system() + + # os.environ["ORT_LIB_LOCATION"] = str(current_dir.joinpath("onnxruntime.dll")) + + if system_name == "Windows": + lib_path = current_dir.joinpath(f"{lib_name}.dll") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.dll") + elif system_name == "Darwin": # macOS + lib_path = current_dir.joinpath(f"{lib_name}.dylib") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.dylib") + elif system_name == "Linux": + lib_path = current_dir.joinpath(f"{lib_name}.so") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.so.1") + else: + raise OSError(f"Unsupported operating system: {system_name}") + + + # onnx_lib_path = current_dir.joinpath("onnxruntime").joinpath("lib") + # current_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + # # Update LD_LIBRARY_PATH + # os.environ["LD_LIBRARY_PATH"] = f"{onnx_lib_path}:{current_ld_library_path}" + # os.environ["ORT_LIB_LOCATION"] = str(onnx_lib_path) + + # ctypes.CDLL(str(onnx_path), mode=ctypes.RTLD_GLOBAL) + onnx_path = current_dir.joinpath("onnxruntime") + + if not lib_path.exists(): + raise FileNotFoundError(f"Shared library not found at: {lib_path}") + + loaded_lib = ctypes.CDLL(str(lib_path)) + loaded_lib.link_onnx.argtypes = [ctypes.c_char_p] + loaded_lib.link_onnx.restype = EmptyReturn + c_string = str(onnx_path).encode('utf-8') + load_info = loaded_lib.link_onnx(c_string) + if load_info.error_message: + raise OSError(f"Failed to load onnxruntime: {load_info.error_message.decode('utf-8')}") + + 
return ctypes.CDLL(str(lib_path)) diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/return_structs.py b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/return_structs.py new file mode 100644 index 0000000..3baaa35 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/return_structs.py @@ -0,0 +1,64 @@ +""" +Defines all the C structs that are used in the tests. +""" +from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float + + +class StringReturn(Structure): + """ + A return type that just returns a string + + Fields: + string: the string that is being returned (only present if successful) + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("string", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + +class EmptyReturn(Structure): + """ + A return type that just returns nothing + + Fields: + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + + +class FileInfo(Structure): + """ + A return type when loading the meta of a surml file. 
+ + Fields: + file_id: a unique identifier for the file in the state of the C lib + name: a name of the model + description: a description of the model + error_message: the error message (only present if error) + is_error: 1 if error, 0 if not + """ + _fields_ = [ + ("file_id", c_char_p), # Corresponds to *mut c_char + ("name", c_char_p), # Corresponds to *mut c_char + ("description", c_char_p), # Corresponds to *mut c_char + ("version", c_char_p), # Corresponds to *mut c_char + ("error_message", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int) # Corresponds to c_int + ] + + +class Vecf32Return(Structure): + _fields_ = [ + ("data", POINTER(c_float)), # Pointer to f32 array + ("length", c_size_t), # Length of the array + ("capacity", c_size_t), # Capacity of the array + ("is_error", c_int), # Indicates if it's an error + ("error_message", c_char_p), # Optional error message + ] diff --git a/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/routes.py b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/routes.py new file mode 100644 index 0000000..07e07e7 --- /dev/null +++ b/modules/c-wrapper/build-context/c-wrapper/tests/test_utils/routes.py @@ -0,0 +1,12 @@ +""" +Defines all the routes for the testing module to all the assets and C libs +""" +from pathlib import Path + + +UTILS_PATH = Path(__file__).parent +ASSETS_PATH = UTILS_PATH.joinpath("assets") +TEST_SURML_PATH = ASSETS_PATH.joinpath("test.surml") +SHOULD_BREAK_FILE = ASSETS_PATH.joinpath("should_break.txt") +TEST_ONNX_FILE_PATH = ASSETS_PATH.joinpath("linear_test.onnx") +ONNX_LIB = UTILS_PATH.joinpath("..").joinpath("..").joinpath("onnx_lib").joinpath("onnxruntime") diff --git a/modules/c-wrapper/build-context/core/.dockerignore b/modules/c-wrapper/build-context/core/.dockerignore new file mode 100644 index 0000000..e0a9958 --- /dev/null +++ b/modules/c-wrapper/build-context/core/.dockerignore @@ -0,0 +1,9 @@ +.idea/ +builds/ +onnx_driver/ +target/ +tests/ +output/ +LICENSE 
+README.md +Cargo.lock diff --git a/modules/c-wrapper/build-context/core/.gitignore b/modules/c-wrapper/build-context/core/.gitignore new file mode 100644 index 0000000..8474aa3 --- /dev/null +++ b/modules/c-wrapper/build-context/core/.gitignore @@ -0,0 +1,6 @@ +onnx_driver/ +target/ +output/ +downloaded_onnx_package/ +src/execution/libonnxruntime.a +libonnxruntime.* diff --git a/modules/c-wrapper/build-context/core/Cargo.toml b/modules/c-wrapper/build-context/core/Cargo.toml new file mode 100644 index 0000000..8d2dac1 --- /dev/null +++ b/modules/c-wrapper/build-context/core/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "surrealml-core" +version = "0.1.3" +edition = "2021" +build = "./build.rs" +description = "The core machine learning library for SurrealML that enables SurrealDB to store and load ML models" +license-file = "LICENSE" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +axum-feature = ["axum"] +actix-feature = ["actix-web"] +default = [] +# below are the features for testing different engines +sklearn-tests = [] +onnx-tests = [] +torch-tests = [] +tensorflow-tests = [] +gpu = [] +dynamic = ["ort/load-dynamic"] + +[dependencies] +regex = "1.9.3" +ort = { version = "2.0.0-rc.9", features = [ "cuda", "ndarray" ]} +ndarray = "0.16.1" +once_cell = "1.18.0" +bytes = "1.5.0" +futures-util = "0.3.28" +futures-core = "0.3.28" +thiserror = "2.0.9" +serde = { version = "1.0.197", features = ["derive"] } +axum = { version = "0.7.4", optional = true } +actix-web = { version = "4.5.1", optional = true } + + +[dev-dependencies] +tokio = { version = "1.12.0", features = ["full"] } + +[lib] +name = "surrealml_core" +path = "src/lib.rs" + +# [build-dependencies] +# ort = { version = "1.16.2", default-features = true } diff --git a/modules/c-wrapper/build-context/core/Dockerfile b/modules/c-wrapper/build-context/core/Dockerfile new file mode 100644 index 0000000..c4f15ef --- /dev/null +++ 
b/modules/c-wrapper/build-context/core/Dockerfile @@ -0,0 +1,37 @@ +# Use an official Rust image +FROM rust:1.83-slim + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + pkg-config \ + && rm -rf /var/lib/apt/lists/* + +# Set the working directory +WORKDIR /app + +# Copy the project files into the container +COPY . . + +# Download ONNX Runtime 1.20.0 +RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz \ + && tar -xvf onnxruntime-linux-x64-1.20.0.tgz \ + && mv onnxruntime-linux-x64-1.20.0 /onnxruntime + +# # Download ONNX Runtime 1.16.0 +# RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.0/onnxruntime-linux-x64-1.16.0.tgz \ +# && tar -xvf onnxruntime-linux-x64-1.16.0.tgz \ +# && mv onnxruntime-linux-x64-1.16.0 /onnxruntime + +# Set the ONNX Runtime library path +ENV ORT_LIB_LOCATION=/onnxruntime/lib +ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# Clean and build the Rust project +RUN cargo clean +RUN cargo build --features tensorflow-tests + +# Run the tests +CMD ["cargo", "test", "--features", "tensorflow-tests"] diff --git a/modules/c-wrapper/build-context/core/LICENSE b/modules/c-wrapper/build-context/core/LICENSE new file mode 100644 index 0000000..4e81b1e --- /dev/null +++ b/modules/c-wrapper/build-context/core/LICENSE @@ -0,0 +1,103 @@ +Business Source License 1.1 + +Parameters + +Licensor: SurrealDB Ltd. +Licensed Work: Surrealml + The Licensed Work is (c) 2022 SurrealDB Ltd. +Additional Use Grant: You may make use of the Licensed Work, provided that + you may not use the Licensed Work for a Database + Service. + + A “Database Service†is a commercial offering that + allows third parties (other than your employees and + contractors) to access the functionality of the + Licensed Work by creating tables whose schemas are + controlled by such third parties. 
+ +Change Date: has not changed yet + +Change License: Apache License, Version 2.0 + +For information about alternative licensing arrangements for the Software, +please visit: https://surrealdb.com + +Notice + +The Business Source License (this document, or the “Licenseâ€) is not an Open +Source license. However, the Licensed Work will eventually be made available +under an Open Source License, as stated in this License. + +License text copyright (c) 2017 MariaDB Corporation Ab, All Rights Reserved. +“Business Source License†is a trademark of MariaDB Corporation Ab. + +----------------------------------------------------------------------------- + +Business Source License 1.1 + +Terms + +The Licensor hereby grants you the right to copy, modify, create derivative +works, redistribute, and make non-production use of the Licensed Work. The +Licensor may make an Additional Use Grant, above, permitting limited +production use. + +Effective on the Change Date, or the fourth anniversary of the first publicly +available distribution of a specific version of the Licensed Work under this +License, whichever comes first, the Licensor hereby grants you rights under +the terms of the Change License, and the rights granted in the paragraph +above terminate. + +If your use of the Licensed Work does not comply with the requirements +currently in effect as described in this License, you must purchase a +commercial license from the Licensor, its affiliated entities, or authorized +resellers, or you must refrain from using the Licensed Work. + +All copies of the original and modified Licensed Work, and derivative works +of the Licensed Work, are subject to this License. This License applies +separately for each version of the Licensed Work and the Change Date may vary +for each version of the Licensed Work released by Licensor. + +You must conspicuously display this License on each original or modified copy +of the Licensed Work. 
If you receive the Licensed Work in original or +modified form from a third party, the terms and conditions set forth in this +License apply to your use of that work. + +Any use of the Licensed Work in violation of this License will automatically +terminate your rights under this License for the current and all other +versions of the Licensed Work. + +This License does not grant you any right in any trademark or logo of +Licensor or its affiliates (provided that you may use a trademark or logo of +Licensor as expressly required by this License). + +TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE LICENSED WORK IS PROVIDED ON +AN “AS IS†BASIS. LICENSOR HEREBY DISCLAIMS ALL WARRANTIES AND CONDITIONS, +EXPRESS OR IMPLIED, INCLUDING (WITHOUT LIMITATION) WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, AND +TITLE. + +MariaDB hereby grants you permission to use this License’s text to license +your works, and to refer to it using the trademark “Business Source Licenseâ€, +as long as you comply with the Covenants of Licensor below. + +Covenants of Licensor + +In consideration of the right to use this License’s text and the “Business +Source License†name and trademark, Licensor covenants to MariaDB, and to all +other recipients of the licensed work to be provided by Licensor: + +1. To specify as the Change License the GPL Version 2.0 or any later version, + or a license that is compatible with GPL Version 2.0 or a later version, + where “compatible†means that software provided under the Change License can + be included in a program with software provided under GPL Version 2.0 or a + later version. Licensor may specify additional Change Licenses without + limitation. + +2. To either: (a) specify an additional grant of rights to use that does not + impose any additional restriction on the right granted in this License, as + the Additional Use Grant; or (b) insert the text “Noneâ€. + +3. To specify a Change Date. + +4. 
Not to modify this License in any other way. \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/README.md b/modules/c-wrapper/build-context/core/README.md new file mode 100644 index 0000000..632458d --- /dev/null +++ b/modules/c-wrapper/build-context/core/README.md @@ -0,0 +1,105 @@ + +# Surml Core + +An embedded ONNX runtime directly in the Rust binary when compiling result in no need for installing ONNX runtime separately or worrying about version clashes with other runtimes. + +This crate is just the Rust implementation of the Surml API. It is advised that you just use this crate directly if you are running a Rust server. It must be noted that the version of ONNX needs to be the same as the client when using this crate. For this current version of Surml, the ONNX version is `1.16.0`. + +## Compilation config + +If nothing is configured the crate will compile the ONNX runtime into the binary. This is the default behaviour. However, if you want to use an ONNX runtime that is installed on your system, you can set the environment variable `ONNXRUNTIME_LIB_PATH` before you compile the crate. This will make the crate use the ONNX runtime that is installed on your system. + +This houses reusable errors that are used across all the crates in the Surml ecosystem, and these errors can construct HTTP responses for the Axum and Actix web frameworks. + +## Nix Support + +At this point in time NIX is not directly supported. The `ONNXRUNTIME_LIB_PATH` needs to be defined. This is explained in the `Compilation config` section. + +## Usage + +Surml can be used to store, load, and execute ONNX models. 
+ +### Storing and accessing models +We can store models and meta data around the models with the following code: +```rust +use std::fs::File; +use std::io::{self, Read, Write}; + +use surrealml_core::storage::surml_file::SurMlFile; +use surrealml_core::storage::header::Header; +use surrealml_core::storage::header::normalisers::{ + wrapper::NormaliserType, + linear_scaling::LinearScaling +}; + + +// load your own model here (surrealml python package can be used to convert PyTorch, +// and Sklearn models to ONNX or package them as surml files) +let mut file = File::open("./stash/linear_test.onnx").unwrap(); +let mut model_bytes = Vec::new(); +file.read_to_end(&mut model_bytes).unwrap(); + +// create a header for the model +let mut header = Header::fresh(); +header.add_column(String::from("squarefoot")); +header.add_column(String::from("num_floors")); +header.add_output(String::from("house_price"), None); + +// add normalisers if needed +header.add_normaliser( + "squarefoot".to_string(), + NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }) +); +header.add_normaliser( + "num_floors".to_string(), + NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }) +); + +// create a surml file +let surml_file = SurMlFile::new(header, model_bytes); + +// read and write surml files +surml_file.write("./stash/test.surml").unwrap(); +let new_file = SurMlFile::from_file("./stash/test.surml").unwrap(); +let file_from_bytes = SurMlFile::from_bytes(surml_file.to_bytes()).unwrap(); +``` + +## Executing models + +We you load a `surml` file, you can execute the model with the following code: + +```rust +use surrealml_core::storage::surml_file::SurMlFile; +use surrealml_core::execution::compute::ModelComputation; +use ndarray::ArrayD; +use std::collections::HashMap; + + +let mut file = SurMlFile::from_file("./stash/test.surml").unwrap(); + +let compute_unit = ModelComputation { + surml_file: &mut file, +}; + +// automatically map inputs and apply normalisers to 
the compute if this data was put in the header +let mut input_values = HashMap::new(); +input_values.insert(String::from("squarefoot"), 1000.0); +input_values.insert(String::from("num_floors"), 2.0); + +let output = compute_unit.buffered_compute(&mut input_values).unwrap(); + +// feed a raw ndarray into the model if no header was provided or if you want to bypass the header +let x = vec![1000.0, 2.0]; +let data: ArrayD = ndarray::arr1(&x).into_dyn(); + +// None input can be a tuple of dimensions of the input data +let output = compute_unit.raw_compute(data, None).unwrap(); +``` + +## ONNX runtime assets + +We can find the ONNX assets with the following link: + +``` +https://github.com/microsoft/onnxruntime/releases/tag/v1.16.2 +``` diff --git a/modules/c-wrapper/build-context/core/build.rs b/modules/c-wrapper/build-context/core/build.rs new file mode 100644 index 0000000..a1fd6aa --- /dev/null +++ b/modules/c-wrapper/build-context/core/build.rs @@ -0,0 +1,4 @@ + +fn main() { + +} diff --git a/modules/c-wrapper/build-context/core/builds/Dockerfile.linux b/modules/c-wrapper/build-context/core/builds/Dockerfile.linux new file mode 100644 index 0000000..14b43ec --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/Dockerfile.linux @@ -0,0 +1,17 @@ +# Start from a base image, e.g., Ubuntu +FROM ubuntu:latest + +RUN apt-get update && apt-get install -y curl build-essential +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y + +ENV PATH="/root/.cargo/bin:${PATH}" + +WORKDIR /app + +COPY . . 
+ +# RUN cargo build --release + +# CMD ["cargo", "test"] +# run in infinite loop +CMD ["tail", "-f", "/dev/null"] diff --git a/modules/c-wrapper/build-context/core/builds/Dockerfile.macos b/modules/c-wrapper/build-context/core/builds/Dockerfile.macos new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/build-context/core/builds/Dockerfile.nix b/modules/c-wrapper/build-context/core/builds/Dockerfile.nix new file mode 100644 index 0000000..4522732 --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/Dockerfile.nix @@ -0,0 +1,21 @@ +# Start from a base image, e.g., Ubuntu +FROM nixos/nix:latest + +# Update Nix channel +RUN nix-channel --update + +# Install Rust and build tools using Nix +RUN nix-env -iA nixpkgs.rustup nixpkgs.gcc nixpkgs.pkg-config nixpkgs.cmake nixpkgs.coreutils + +# Initialize Rust environment +RUN rustup default stable + +ENV PATH="/root/.cargo/bin:${PATH}" + +WORKDIR /app + +COPY . . + +# RUN cargo build --release + +CMD ["cargo", "run"] diff --git a/modules/c-wrapper/build-context/core/builds/Dockerfile.windows b/modules/c-wrapper/build-context/core/builds/Dockerfile.windows new file mode 100644 index 0000000..2206c1e --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/Dockerfile.windows @@ -0,0 +1,37 @@ +# # Use a Windows base image +# FROM mcr.microsoft.com/dotnet/core/sdk:2.1 + +# # Install Rust +# RUN powershell -Command \ +# $ErrorActionPreference = 'Stop'; \ +# [Net.ServicePointManager]::SecurityProtocol = [Net.SecurityProtocolType]::Tls12; \ +# Invoke-WebRequest https://win.rustup.rs -OutFile rustup-init.exe; \ +# Start-Process ./rustup-init.exe -ArgumentList '-y' -Wait; \ +# Remove-Item rustup-init.exe + +# # Add Cargo to PATH +# ENV PATH="C:\\Users\\ContainerAdministrator\\.cargo\\bin;${PATH}" + +# WORKDIR /app +# COPY . . 
+ +# # Command to run on container start +# CMD ["cargo", "run"] + +# Use the latest Windows Server Core image +FROM mcr.microsoft.com/windows:ltsc2019 + +# Set the working directory to C:\app +WORKDIR C:\app + +# Install Rust +RUN powershell.exe -Command "Invoke-WebRequest https://win.rustup.rs -OutFile rustup-init.exe; .\rustup-init.exe -y" + +# Add Rust to the PATH environment variable +RUN setx /M PATH $('C:\Users\ContainerAdministrator\.cargo\bin;' + $Env:PATH) + +# Copy the source code into the container +COPY . . + +# Run the application +CMD ["cargo", "run"] diff --git a/modules/c-wrapper/build-context/core/builds/docker_configs/linux.yml b/modules/c-wrapper/build-context/core/builds/docker_configs/linux.yml new file mode 100644 index 0000000..88d591b --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/docker_configs/linux.yml @@ -0,0 +1,15 @@ +version: "3.8" + +services: + linux_surrealml_core: + build: + context: . + dockerfile: builds/Dockerfile.linux + restart: unless-stopped + # command: tail -f /dev/null + environment: + TEST: test_env + volumes: + - ./output/linux:/app/output + ports: + - "8001:8001" diff --git a/modules/c-wrapper/build-context/core/builds/docker_configs/macos.yml b/modules/c-wrapper/build-context/core/builds/docker_configs/macos.yml new file mode 100644 index 0000000..758357d --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/docker_configs/macos.yml @@ -0,0 +1,12 @@ +version: "3.8" + +services: + surrealml_core: + build: + context: . 
+ dockerfile: builds/Dockerfile.macos + restart: unless-stopped + environment: + TEST: test_env + ports: + - "8001:8001" diff --git a/modules/c-wrapper/build-context/core/builds/docker_configs/nix.yml b/modules/c-wrapper/build-context/core/builds/docker_configs/nix.yml new file mode 100644 index 0000000..ee3323d --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/docker_configs/nix.yml @@ -0,0 +1,13 @@ +version: "3.8" + +services: + nix_surrealml_core: + build: + context: . + dockerfile: builds/Dockerfile.nix + restart: unless-stopped + command: tail -f /dev/null + environment: + TEST: test_env + ports: + - "8001:8001" diff --git a/modules/c-wrapper/build-context/core/builds/docker_configs/windows.yml b/modules/c-wrapper/build-context/core/builds/docker_configs/windows.yml new file mode 100644 index 0000000..3961582 --- /dev/null +++ b/modules/c-wrapper/build-context/core/builds/docker_configs/windows.yml @@ -0,0 +1,13 @@ +version: "3.8" + +services: + windows_surrealml_core: + build: + context: . 
+ dockerfile: builds/Dockerfile.windows + restart: unless-stopped + command: tail -f /dev/null + environment: + TEST: test_env + ports: + - "8001:8001" diff --git a/modules/c-wrapper/build-context/core/docker-compose.yml b/modules/c-wrapper/build-context/core/docker-compose.yml new file mode 100644 index 0000000..98965f7 --- /dev/null +++ b/modules/c-wrapper/build-context/core/docker-compose.yml @@ -0,0 +1,8 @@ +version: '3.8' + +services: + + busybox: + image: busybox + # command: tail -f /dev/null + command: echo "Hello World" diff --git a/modules/c-wrapper/build-context/core/model_stash/onnx/onnx/linear.onnx b/modules/c-wrapper/build-context/core/model_stash/onnx/onnx/linear.onnx new file mode 100644 index 0000000..bc1dbf6 Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/onnx/onnx/linear.onnx differ diff --git a/modules/c-wrapper/build-context/core/model_stash/onnx/surml/linear.surml b/modules/c-wrapper/build-context/core/model_stash/onnx/surml/linear.surml new file mode 100644 index 0000000..f092b50 Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/onnx/surml/linear.surml differ diff --git a/modules/c-wrapper/build-context/core/model_stash/sklearn/onnx/linear.onnx b/modules/c-wrapper/build-context/core/model_stash/sklearn/onnx/linear.onnx new file mode 100644 index 0000000..bc1dbf6 Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/sklearn/onnx/linear.onnx differ diff --git a/modules/c-wrapper/build-context/core/model_stash/sklearn/surml/linear.surml b/modules/c-wrapper/build-context/core/model_stash/sklearn/surml/linear.surml new file mode 100644 index 0000000..f092b50 Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/sklearn/surml/linear.surml differ diff --git a/modules/c-wrapper/build-context/core/model_stash/tensorflow/surml/linear.surml b/modules/c-wrapper/build-context/core/model_stash/tensorflow/surml/linear.surml new file mode 100644 index 
0000000..628e257 Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/tensorflow/surml/linear.surml differ diff --git a/modules/c-wrapper/build-context/core/model_stash/torch/surml/linear.surml b/modules/c-wrapper/build-context/core/model_stash/torch/surml/linear.surml new file mode 100644 index 0000000..59864af Binary files /dev/null and b/modules/c-wrapper/build-context/core/model_stash/torch/surml/linear.surml differ diff --git a/modules/c-wrapper/build-context/core/onnxruntime-linux-x64-1.20.0.tgz b/modules/c-wrapper/build-context/core/onnxruntime-linux-x64-1.20.0.tgz new file mode 100644 index 0000000..2cacda3 Binary files /dev/null and b/modules/c-wrapper/build-context/core/onnxruntime-linux-x64-1.20.0.tgz differ diff --git a/modules/c-wrapper/build-context/core/scripts/install_onnxruntime_linux.sh b/modules/c-wrapper/build-context/core/scripts/install_onnxruntime_linux.sh new file mode 100644 index 0000000..c76605c --- /dev/null +++ b/modules/c-wrapper/build-context/core/scripts/install_onnxruntime_linux.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Variables +ONNX_VERSION="1.20.0" +ONNX_DOWNLOAD_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz" +ONNX_RUNTIME_DIR="/home/maxwellflitton/Documents/github/surreal/surrealml/modules/core/target/debug/build/ort-680c63907dcb00d8/out/onnxruntime" +ONNX_TARGET_DIR="${ONNX_RUNTIME_DIR}/onnxruntime-linux-x64-${ONNX_VERSION}" +LD_LIBRARY_PATH_UPDATE="${ONNX_TARGET_DIR}/lib" + +# Step 1: Download and Extract ONNX Runtime +echo "Downloading ONNX Runtime version ${ONNX_VERSION}..." +wget -q --show-progress "${ONNX_DOWNLOAD_URL}" -O "onnxruntime-linux-x64-${ONNX_VERSION}.tgz" + +if [ $? -ne 0 ]; then + echo "Failed to download ONNX Runtime. Exiting." + exit 1 +fi + +echo "Extracting ONNX Runtime..." +tar -xvf "onnxruntime-linux-x64-${ONNX_VERSION}.tgz" + +if [ ! 
-d "onnxruntime-linux-x64-${ONNX_VERSION}" ]; then + echo "Extraction failed. Directory not found. Exiting." + exit 1 +fi + +# Step 2: Replace Old ONNX Runtime +echo "Replacing old ONNX Runtime..." +mkdir -p "${ONNX_RUNTIME_DIR}" +mv "onnxruntime-linux-x64-${ONNX_VERSION}" "${ONNX_TARGET_DIR}" + +if [ ! -d "${ONNX_TARGET_DIR}" ]; then + echo "Failed to move ONNX Runtime to target directory. Exiting." + exit 1 +fi + +# Step 3: Update LD_LIBRARY_PATH +echo "Updating LD_LIBRARY_PATH..." +export LD_LIBRARY_PATH="${LD_LIBRARY_PATH_UPDATE}:$LD_LIBRARY_PATH" + +# Step 4: Verify Library Version +echo "Verifying ONNX Runtime version..." +strings "${LD_LIBRARY_PATH_UPDATE}/libonnxruntime.so" | grep "VERS_${ONNX_VERSION}" > /dev/null + +if [ $? -ne 0 ]; then + echo "ONNX Runtime version ${ONNX_VERSION} not found in library. Exiting." + exit 1 +fi + +# Step 5: Install Library Globally (Optional) +echo "Installing ONNX Runtime globally..." +sudo cp "${LD_LIBRARY_PATH_UPDATE}/libonnxruntime.so" /usr/local/lib/ +sudo ldconfig + +if [ $? -ne 0 ]; then + echo "Failed to install ONNX Runtime globally. Exiting." + exit 1 +fi + +# Step 6: Clean and Rebuild Project +echo "Cleaning and rebuilding project..." +cargo clean +cargo test --features tensorflow-tests + +if [ $? -eq 0 ]; then + echo "ONNX Runtime updated successfully, and tests passed." +else + echo "ONNX Runtime updated, but tests failed. Check the logs for details." +fi diff --git a/modules/c-wrapper/build-context/core/scripts/linux_compose.sh b/modules/c-wrapper/build-context/core/scripts/linux_compose.sh new file mode 100644 index 0000000..eae85ea --- /dev/null +++ b/modules/c-wrapper/build-context/core/scripts/linux_compose.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. 
+ +# compose_command=$1 + +# docker-compose -f docker-compose.yml -f aarch.yml $1 +docker-compose -f docker-compose.yml -f builds/docker_configs/linux.yml $1 \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/scripts/nix_compose.sh b/modules/c-wrapper/build-context/core/scripts/nix_compose.sh new file mode 100644 index 0000000..a319fcb --- /dev/null +++ b/modules/c-wrapper/build-context/core/scripts/nix_compose.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +# compose_command=$1 + +# docker-compose -f docker-compose.yml -f aarch.yml $1 +docker-compose -f docker-compose.yml -f builds/docker_configs/nix.yml $1 \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/scripts/windows_compose.sh b/modules/c-wrapper/build-context/core/scripts/windows_compose.sh new file mode 100644 index 0000000..a19af2a --- /dev/null +++ b/modules/c-wrapper/build-context/core/scripts/windows_compose.sh @@ -0,0 +1,12 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +# compose_command=$1 + +# docker-compose -f docker-compose.yml -f aarch.yml $1 +docker-compose -f docker-compose.yml -f builds/docker_configs/windows.yml $1 diff --git a/modules/c-wrapper/build-context/core/src/errors/actix.rs b/modules/c-wrapper/build-context/core/src/errors/actix.rs new file mode 100644 index 0000000..3ad0b95 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/errors/actix.rs @@ -0,0 +1,57 @@ +//! Implements the `ResponseError` trait for the `SurrealError` type for the `actix_web` web framework. +use actix_web::{HttpResponse, error::ResponseError, http::StatusCode}; +pub use crate::errors::error::{SurrealErrorStatus, SurrealError}; + + +impl ResponseError for SurrealError { + + /// Yields the status code for the error. 
+ /// + /// # Returns + /// * `StatusCode` - The status code for the error. + fn status_code(&self) -> StatusCode { + match self.status { + SurrealErrorStatus::NotFound => StatusCode::NOT_FOUND, + SurrealErrorStatus::Forbidden => StatusCode::FORBIDDEN, + SurrealErrorStatus::Unknown => StatusCode::INTERNAL_SERVER_ERROR, + SurrealErrorStatus::BadRequest => StatusCode::BAD_REQUEST, + SurrealErrorStatus::Conflict => StatusCode::CONFLICT, + SurrealErrorStatus::Unauthorized => StatusCode::UNAUTHORIZED + } + } + + /// Constructs a HTTP response for the error. + /// + /// # Returns + /// * `HttpResponse` - The HTTP response for the error. + fn error_response(&self) -> HttpResponse { + let status_code = self.status_code(); + HttpResponse::build(status_code).json(self.message.clone()) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use actix_web::http::StatusCode; + + #[test] + fn test_status_code() { + let error = SurrealError { + message: "Test".to_string(), + status: SurrealErrorStatus::NotFound + }; + assert_eq!(error.status_code(), StatusCode::NOT_FOUND); + } + + #[test] + fn test_error_response() { + let error = SurrealError { + message: "Test".to_string(), + status: SurrealErrorStatus::NotFound + }; + let response = error.error_response(); + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } +} diff --git a/modules/c-wrapper/build-context/core/src/errors/axum.rs b/modules/c-wrapper/build-context/core/src/errors/axum.rs new file mode 100644 index 0000000..b5a3821 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/errors/axum.rs @@ -0,0 +1,44 @@ +//! Implements the `IntoResponse` trait for the `SurrealError` type for the `axum` web framework. +use axum::response::{IntoResponse, Response}; +use axum::body::Body; +pub use crate::errors::error::{SurrealErrorStatus, SurrealError}; + + +impl IntoResponse for SurrealError { + + /// Constructs a HTTP response for the error. + /// + /// # Returns + /// * `Response` - The HTTP response for the error. 
+ fn into_response(self) -> Response { + let status_code = match self.status { + SurrealErrorStatus::NotFound => axum::http::StatusCode::NOT_FOUND, + SurrealErrorStatus::Forbidden => axum::http::StatusCode::FORBIDDEN, + SurrealErrorStatus::Unknown => axum::http::StatusCode::INTERNAL_SERVER_ERROR, + SurrealErrorStatus::BadRequest => axum::http::StatusCode::BAD_REQUEST, + SurrealErrorStatus::Conflict => axum::http::StatusCode::CONFLICT, + SurrealErrorStatus::Unauthorized => axum::http::StatusCode::UNAUTHORIZED + }; + axum::http::Response::builder() + .status(status_code) + .body(Body::new(self.message)) + .unwrap() + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::StatusCode; + + #[test] + fn test_into_response() { + let error = SurrealError { + message: "Test".to_string(), + status: SurrealErrorStatus::NotFound + }; + let response = error.into_response(); + assert_eq!(response.status(), StatusCode::NOT_FOUND); + } +} diff --git a/modules/c-wrapper/build-context/core/src/errors/error.rs b/modules/c-wrapper/build-context/core/src/errors/error.rs new file mode 100644 index 0000000..edc30f2 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/errors/error.rs @@ -0,0 +1,101 @@ +//! Custom error that can be attached to a web framework to automcatically result in a http response, +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use std::fmt; + + +#[macro_export] +macro_rules! safe_eject { + // Match when the optional string is provided + ($e:expr, $err_status:expr, $msg:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, $err_status)})? 
+ }; + // Match when the optional string is not provided + ($e:expr, $err_status:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, $err_status)})? + }; +} + + +#[macro_export] +macro_rules! safe_eject_internal { + // Match when the optional string is provided + ($e:expr, $err_status:expr, $msg:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, SurrealErrorStatus::Unknown)})? + }; + // Match when the optional string is not provided + ($e:expr) => { + $e.map_err(|x| {let file_track = format!("{}:{}", file!(), line!()); let formatted_error = format!("{} => {}", file_track, x.to_string()); SurrealError::new(formatted_error, SurrealErrorStatus::Unknown)})? + }; +} + + +#[macro_export] +macro_rules! safe_eject_option { + ($check:expr) => { + match $check {Some(x) => x, None => {let file_track = format!("{}:{}", file!(), line!());let message = format!("{}=>The value is not found", file_track);return Err(SurrealError::new(message, SurrealErrorStatus::NotFound))}} + }; +} + + +/// The status of the custom error. +/// +/// # Fields +/// * `NotFound` - The request was not found. +/// * `Forbidden` - You are forbidden to access. +/// * `Unknown` - An unknown internal error occurred. +/// * `BadRequest` - The request was bad. +/// * `Conflict` - The request conflicted with the current state of the server. 
+#[derive(Error, Debug, Serialize, Deserialize, PartialEq)] +pub enum SurrealErrorStatus { + #[error("not found")] + NotFound, + #[error("You are forbidden to access resource")] + Forbidden, + #[error("Unknown Internal Error")] + Unknown, + #[error("Bad Request")] + BadRequest, + #[error("Conflict")] + Conflict, + #[error("Unauthorized")] + Unauthorized +} + + +/// The custom error that the web framework will construct into a HTTP response. +/// +/// # Fields +/// * `message` - The message of the error. +/// * `status` - The status of the error. +#[derive(Serialize, Deserialize, Debug, Error)] +pub struct SurrealError { + pub message: String, + pub status: SurrealErrorStatus +} + + +impl SurrealError { + + /// Create a new custom error. + /// + /// # Arguments + /// * `message` - The message of the error. + /// * `status` - The status of the error. + /// + /// # Returns + /// A new custom error. + pub fn new(message: String, status: SurrealErrorStatus) -> Self { + SurrealError { + message, + status + } + } +} + + +impl fmt::Display for SurrealError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} diff --git a/modules/c-wrapper/build-context/core/src/errors/mod.rs b/modules/c-wrapper/build-context/core/src/errors/mod.rs new file mode 100644 index 0000000..e61abff --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/errors/mod.rs @@ -0,0 +1,7 @@ +pub mod error; + +#[cfg(feature = "actix-feature")] +pub mod actix; + +#[cfg(feature = "axum-feature")] +pub mod axum; diff --git a/modules/c-wrapper/build-context/core/src/execution/compute.rs b/modules/c-wrapper/build-context/core/src/execution/compute.rs new file mode 100644 index 0000000..584575e --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/execution/compute.rs @@ -0,0 +1,314 @@ +//! Defines the operations around performing computations on a loaded model. 
+use crate::storage::surml_file::SurMlFile; +use std::collections::HashMap; +use ndarray::ArrayD; +use ort::value::ValueType; +use ort::session::Session; + +use crate::safe_eject; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use crate::execution::session::get_session; + + +/// A wrapper for the loaded machine learning model so we can perform computations on the loaded model. +/// +/// # Attributes +/// * `surml_file` - The loaded machine learning model using interior mutability to allow mutable access to the model +pub struct ModelComputation<'a> { + pub surml_file: &'a mut SurMlFile, +} + + +impl <'a>ModelComputation<'a> { + + /// Creates a Tensor that can be used as input to the loaded model from a hashmap of keys and values. + /// + /// # Arguments + /// * `input_values` - A hashmap of keys and values that will be used to create the input tensor. + /// + /// # Returns + /// A Tensor that can be used as input to the loaded model. + pub fn input_tensor_from_key_bindings(&self, input_values: HashMap) -> Result, SurrealError> { + let buffer = self.input_vector_from_key_bindings(input_values)?; + Ok(ndarray::arr1::(&buffer).into_dyn()) + } + + /// Creates a vector of dimensions for the input tensor from the loaded model. + /// + /// # Arguments + /// * `input_dims` - The input dimensions from the loaded model. + /// + /// # Returns + /// A vector of dimensions for the input tensor to be reshaped into from the loaded model. 
+ fn process_input_dims(session_ref: &Session) -> Vec { + let some_dims = match &session_ref.inputs[0].input_type { + ValueType::Tensor{ ty: _, dimensions: new_dims, dimension_symbols: _ } => Some(new_dims), + _ => None + }; + let mut dims_cache = Vec::new(); + for dim in some_dims.unwrap() { + if dim < &0 { + dims_cache.push((dim * -1) as usize); + } + else { + dims_cache.push(*dim as usize); + } + } + dims_cache + } + + /// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values. + /// + /// # Arguments + /// * `input_values` - A hashmap of keys and values that will be used to create the input vector. + /// + /// # Returns + /// A Vector that can be used manipulated with other operations such as normalisation. + pub fn input_vector_from_key_bindings(&self, mut input_values: HashMap) -> Result, SurrealError> { + let mut buffer = Vec::with_capacity(self.surml_file.header.keys.store.len()); + + for key in &self.surml_file.header.keys.store { + let value = match input_values.get_mut(key) { + Some(value) => value, + None => return Err(SurrealError::new(format!("src/execution/compute.rs 67: Key {} not found in input values", key), SurrealErrorStatus::NotFound)) + }; + buffer.push(std::mem::take(value)); + } + + Ok(buffer) + } + + /// Performs a raw computation on the loaded model. + /// + /// # Arguments + /// * `tensor` - The input tensor to the loaded model. + /// + /// # Returns + /// The computed output tensor from the loaded model. 
+ pub fn raw_compute(&self, tensor: ArrayD, _dims: Option<(i32, i32)>) -> Result, SurrealError> { + let session = get_session(self.surml_file.model.clone())?; + let dims_cache = ModelComputation::process_input_dims(&session); + let tensor = match tensor.into_shape_with_order(dims_cache) { + Ok(tensor) => tensor, + Err(_) => return Err(SurrealError::new("Failed to reshape tensor to input dimensions".to_string(), SurrealErrorStatus::Unknown)) + }; + let tensor = match ort::value::Tensor::from_array(tensor) { + Ok(tensor) => tensor, + Err(_) => return Err(SurrealError::new("Failed to convert tensor to ort tensor".to_string(), SurrealErrorStatus::Unknown)) + }; + let x = match ort::inputs![tensor] { + Ok(x) => x, + Err(_) => return Err(SurrealError::new("Failed to create input tensor".to_string(), SurrealErrorStatus::Unknown)) + }; + let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown); + + let mut buffer: Vec = Vec::new(); + + // extract the output tensor converting the values to f32 if they are i64 + match outputs[0].try_extract_tensor::() { + Ok(y) => { + for i in y.view().clone().into_iter() { + buffer.push(*i); + } + }, + Err(_) => { + for i in safe_eject!(outputs[0].try_extract_tensor::(), SurrealErrorStatus::Unknown).view().clone().into_iter() { + buffer.push(*i as f32); + } + } + }; + return Ok(buffer) + } + + /// Checks the header applying normalisers if present and then performs a raw computation on the loaded model. Will + /// also apply inverse normalisers if present on the outputs. + /// + /// # Notes + /// This function is fairly coupled and will consider breaking out the functions later on if needed. + /// + /// # Arguments + /// * `input_values` - A hashmap of keys and values that will be used to create the input tensor. + /// + /// # Returns + /// The computed output tensor from the loaded model. 
+ pub fn buffered_compute(&self, input_values: &mut HashMap) -> Result, SurrealError> { + // applying normalisers if present + for (key, value) in &mut *input_values { + let value_ref = value.clone(); + match self.surml_file.header.get_normaliser(&key.to_string())? { + Some(normaliser) => { + *value = normaliser.normalise(value_ref); + }, + None => {} + } + } + let tensor = self.input_tensor_from_key_bindings(input_values.clone())?; + let output = self.raw_compute(tensor, None)?; + + // if no normaliser is present, return the output + if self.surml_file.header.output.normaliser == None { + return Ok(output) + } + + // apply the normaliser to the output + let output_normaliser = match self.surml_file.header.output.normaliser.as_ref() { + Some(normaliser) => normaliser, + None => return Err(SurrealError::new( + String::from("No normaliser present for output which shouldn't happen as passed initial check for").to_string(), + SurrealErrorStatus::Unknown + )) + }; + let mut buffer = Vec::with_capacity(output.len()); + + for value in output { + buffer.push(output_normaliser.inverse_normalise(value)); + } + return Ok(buffer) + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[cfg(feature = "sklearn-tests")] + #[test] + fn test_raw_compute_linear_sklearn() { + let mut file = SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap(); + + let output = model_computation.raw_compute(raw_input, Some((1, 2))).unwrap(); + assert_eq!(output.len(), 1); + assert_eq!(output[0], 985.57745); + } + + #[cfg(feature = "sklearn-tests")] + #[test] + fn test_buffered_compute_linear_sklearn() { + let mut file = 
SurMlFile::from_file("./model_stash/sklearn/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let output = model_computation.buffered_compute(&mut input_values).unwrap(); + assert_eq!(output.len(), 1); + } + + #[cfg(feature = "onnx-tests")] + #[test] + fn test_raw_compute_linear_onnx() { + let mut file = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap(); + + let output = model_computation.raw_compute(raw_input, Some((1, 2))).unwrap(); + assert_eq!(output.len(), 1); + assert_eq!(output[0], 985.57745); + } + + #[cfg(feature = "onnx-tests")] + #[test] + fn test_buffered_compute_linear_onnx() { + let mut file = SurMlFile::from_file("./model_stash/onnx/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let output = model_computation.buffered_compute(&mut input_values).unwrap(); + assert_eq!(output.len(), 1); + } + + #[cfg(feature = "torch-tests")] + #[test] + fn test_raw_compute_linear_torch() { + let mut file = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + 
input_values.insert(String::from("num_floors"), 2.0); + + let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap(); + + let output = model_computation.raw_compute(raw_input, None).unwrap(); + assert_eq!(output.len(), 1); + } + + #[cfg(feature = "torch-tests")] + #[test] + fn test_buffered_compute_linear_torch() { + let mut file = SurMlFile::from_file("./model_stash/torch/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let output = model_computation.buffered_compute(&mut input_values).unwrap(); + assert_eq!(output.len(), 1); + } + + #[cfg(feature = "tensorflow-tests")] + #[test] + fn test_raw_compute_linear_tensorflow() { + let mut file = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap(); + + let output = model_computation.raw_compute(raw_input, None).unwrap(); + assert_eq!(output.len(), 1); + } + + #[cfg(feature = "tensorflow-tests")] + #[test] + fn test_buffered_compute_linear_tensorflow() { + let mut file = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let output = model_computation.buffered_compute(&mut input_values).unwrap(); + assert_eq!(output.len(), 1); + } +} diff --git 
a/modules/c-wrapper/build-context/core/src/execution/mod.rs b/modules/c-wrapper/build-context/core/src/execution/mod.rs new file mode 100644 index 0000000..5b9d66c --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/execution/mod.rs @@ -0,0 +1,4 @@ +//! Defines operations around performing computations on a loaded model. +pub mod compute; +// pub mod onnx_environment; +pub mod session; diff --git a/modules/c-wrapper/build-context/core/src/execution/onnx_environment.rs b/modules/c-wrapper/build-context/core/src/execution/onnx_environment.rs new file mode 100644 index 0000000..ded5add --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/execution/onnx_environment.rs @@ -0,0 +1,93 @@ +//! This module defines the ONNX environment for the execution of ONNX models. +use once_cell::sync::Lazy; +use ort::{Environment, ExecutionProvider}; +use std::sync::Arc; + +// Compiles the ONNX module into the rust binary. +#[cfg(all( + target_os = "macos", + not(doc), + not(onnx_runtime_env_var_set), + not(onnx_statically_linked) +))] +pub static LIB_BYTES: &'static [u8] = include_bytes!("../../libonnxruntime.dylib"); + +#[cfg(all( + any(target_os = "linux", target_os = "android"), + not(doc), + not(onnx_runtime_env_var_set), + not(onnx_statically_linked) +))] +pub static LIB_BYTES: &'static [u8] = include_bytes!("../../libonnxruntime.so"); + +#[cfg(all( + target_os = "windows", + not(doc), + not(onnx_runtime_env_var_set), + not(onnx_statically_linked) +))] +pub static LIB_BYTES: &'static [u8] = include_bytes!("../../libonnxruntime.dll"); + +// Fallback for documentation and other targets +#[cfg(any( + doc, + onnx_runtime_env_var_set, + onnx_statically_linked, + not(any( + target_os = "macos", + target_os = "linux", + target_os = "android", + target_os = "windows" + )) +))] +pub static LIB_BYTES: &'static [u8] = &[]; + +// the ONNX environment which loads the library +pub static ENVIRONMENT: Lazy> = Lazy::new(|| { + if cfg!(onnx_statically_linked) { + return Arc::new( 
+ Environment::builder() + .with_execution_providers([ExecutionProvider::CPU(Default::default())]) + .build() + .unwrap(), + ); + } + + // if the "ONNXRUNTIME_LIB_PATH" is provided we do not need to compile the ONNX library, instead we just point to the library + // in the "ONNXRUNTIME_LIB_PATH" and load that. + match std::env::var("ONNXRUNTIME_LIB_PATH") { + Ok(path) => { + std::env::set_var("ORT_DYLIB_PATH", path); + return Arc::new( + Environment::builder() + .with_execution_providers([ExecutionProvider::CPU(Default::default())]) + .build() + .unwrap(), + ); + } + // if the "ONNXRUNTIME_LIB_PATH" is not provided we use the `LIB_BYTES` that is the ONNX library compiled into the binary. + // we write the `LIB_BYTES` to a temporary file and then load that file. + Err(_) => { + let current_dir = std::env::current_dir().unwrap(); + let current_dir = current_dir.to_str().unwrap(); + let write_dir = std::path::Path::new(current_dir).join("libonnxruntime.dylib"); + + #[cfg(any(not(doc), not(onnx_runtime_env_var_set)))] + let _ = std::fs::write(write_dir.clone(), LIB_BYTES); + + std::env::set_var("ORT_DYLIB_PATH", write_dir.clone()); + let environment = Arc::new( + Environment::builder() + .with_execution_providers([ExecutionProvider::CPU(Default::default())]) + .build() + .unwrap(), + ); + std::env::remove_var("ORT_DYLIB_PATH"); + + #[cfg(any(not(doc), not(onnx_runtime_env_var_set)))] + let _ = std::fs::remove_file(write_dir); + + return environment; + } + } +}); diff --git a/modules/c-wrapper/build-context/core/src/execution/session.rs b/modules/c-wrapper/build-context/core/src/execution/session.rs new file mode 100644 index 0000000..69a847b --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/execution/session.rs @@ -0,0 +1,58 @@ +//! Defines the session module for the execution module. 
+use ort::session::Session; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use crate::safe_eject; + +#[cfg(feature = "dynamic")] +use once_cell::sync::Lazy; +#[cfg(feature = "dynamic")] +use ort::environment::{EnvironmentBuilder, Environment}; +#[cfg(feature = "dynamic")] +use std::sync::{Arc, Mutex}; + +use std::sync::LazyLock; + + +/// Creates a session for a model. +/// +/// # Arguments +/// * `model_bytes` - The model bytes (usually extracted fromt the surml file) +/// +/// # Returns +/// A session object. +pub fn get_session(model_bytes: Vec) -> Result { + let builder = safe_eject!(Session::builder(), SurrealErrorStatus::Unknown); + + #[cfg(feature = "gpu")] + { + let cuda = CUDAExecutionProvider::default(); + if let Err(e) = cuda.register(&builder) { + eprintln!("Failed to register CUDA: {:?}. Falling back to CPU.", e); + } else { + println!("CUDA registered successfully"); + } + } + let session: Session = safe_eject!(builder + .commit_from_memory(&model_bytes), SurrealErrorStatus::Unknown); + Ok(session) +} + + +// #[cfg(feature = "dynamic")] +// pub static ORT_ENV: LazyLock>>>> = LazyLock::new(|| Arc::new(Mutex::new(None))); + + +#[cfg(feature = "dynamic")] +pub fn set_environment(dylib_path: String) -> Result<(), SurrealError> { + + let outcome: EnvironmentBuilder = ort::init_from(dylib_path); + match outcome.commit() { + Ok(env) => { + // ORT_ENV.lock().unwrap().replace(env); + }, + Err(e) => { + return Err(SurrealError::new(e.to_string(), SurrealErrorStatus::Unknown)); + } + } + Ok(()) +} diff --git a/modules/c-wrapper/build-context/core/src/lib.rs b/modules/c-wrapper/build-context/core/src/lib.rs new file mode 100644 index 0000000..5ec0080 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/lib.rs @@ -0,0 +1,98 @@ +//! An embedded ONNX runtime directly in the Rust binary when compiling result in no need for installing ONNX runtime separately +//! or worrying about version clashes with other runtimes. +//! +//! 
This crate is just the Rust implementation of the Surml API. It is advised that you just use this crate directly if you are running +//! a Rust server. It must be noted that the version of ONNX needs to be the same as the client when using this crate. For this current +//! version of Surml, the ONNX version is `1.16.0`. +//! +//! ## Compilation config +//! If nothing is configured the crate will compiled the ONNX runtime into the binary. This is the default behaviour. However, if you +//! want to use an ONNX runtime that is installed on your system, you can set the environment variable `ONNXRUNTIME_LIB_PATH` before +//! you compile the crate. This will make the crate use the ONNX runtime that is installed on your system. +//! +//! ## Usage +//! Surml can be used to store, load, and execute ONNX models. +//! +//! ### Storing and accessing models +//! We can store models and meta data around the models with the following code: +//! ```rust +//! use std::fs::File; +//! use std::io::{self, Read, Write}; +//! +//! use surrealml_core::storage::surml_file::SurMlFile; +//! use surrealml_core::storage::header::Header; +//! use surrealml_core::storage::header::normalisers::{ +//! wrapper::NormaliserType, +//! linear_scaling::LinearScaling +//! }; +//! +//! +//! // load your own model here (surrealml python package can be used to convert PyTorch, +//! // and Sklearn models to ONNX or package them as surml files) +//! let mut file = File::open("./stash/linear_test.onnx").unwrap(); +//! let mut model_bytes = Vec::new(); +//! file.read_to_end(&mut model_bytes).unwrap(); +//! +//! // create a header for the model +//! let mut header = Header::fresh(); +//! header.add_column(String::from("squarefoot")); +//! header.add_column(String::from("num_floors")); +//! header.add_output(String::from("house_price"), None); +//! +//! // add normalisers if needed +//! header.add_normaliser( +//! "squarefoot".to_string(), +//! 
NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }) +//! ); +//! header.add_normaliser( +//! "num_floors".to_string(), +//! NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }) +//! ); +//! +//! // create a surml file +//! let surml_file = SurMlFile::new(header, model_bytes); +//! +//! // read and write surml files +//! surml_file.write("./stash/test.surml").unwrap(); +//! let new_file = SurMlFile::from_file("./stash/test.surml").unwrap(); +//! let file_from_bytes = SurMlFile::from_bytes(surml_file.to_bytes()).unwrap(); +//! ``` +//! +//! ### Executing models +//! We you load a `surml` file, you can execute the model with the following code: +//! ```no_run +//! use surrealml_core::storage::surml_file::SurMlFile; +//! use surrealml_core::execution::compute::ModelComputation; +//! use ndarray::ArrayD; +//! use std::collections::HashMap; +//! +//! +//! let mut file = SurMlFile::from_file("./stash/test.surml").unwrap(); +//! +//! let compute_unit = ModelComputation { +//! surml_file: &mut file, +//! }; +//! +//! // automatically map inputs and apply normalisers to the compute if this data was put in the header +//! let mut input_values = HashMap::new(); +//! input_values.insert(String::from("squarefoot"), 1000.0); +//! input_values.insert(String::from("num_floors"), 2.0); +//! +//! let output = compute_unit.buffered_compute(&mut input_values).unwrap(); +//! +//! // feed a raw ndarray into the model if no header was provided or if you want to bypass the header +//! let x = vec![1000.0, 2.0]; +//! let data: ArrayD = ndarray::arr1(&x).into_dyn(); +//! +//! // None input can be a tuple of dimensions of the input data +//! let output = compute_unit.raw_compute(data, None).unwrap(); +//! ``` +pub mod storage; +pub mod execution; +pub mod errors; + + +/// Returns the version of the ONNX runtime that is used. 
+pub fn onnx_runtime() -> &'static str { + "1.20.0" +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/engine.rs b/modules/c-wrapper/build-context/core/src/storage/header/engine.rs new file mode 100644 index 0000000..6dfbc28 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/engine.rs @@ -0,0 +1,55 @@ +//! Defines the placeholder for the type of model engine in the header. + + +/// Defines the type of engine being used to run the model. +/// +/// # Fields +/// * `Native` - The native engine which will be native rust and linfa. +/// * `PyTorch` - The PyTorch engine which will be PyTorch and tch-rs. +/// * `Undefined` - The undefined engine which will be used when the engine is not defined. +#[derive(Debug, PartialEq)] +pub enum Engine { + Native, + PyTorch, + Undefined +} + + +impl Engine { + + /// Creates a new `Engine` struct with the undefined engine. + /// + /// # Returns + /// A new `Engine` struct with the undefined engine. + pub fn fresh() -> Self { + Engine::Undefined + } + + /// Creates a new `Engine` struct from a string. + /// + /// # Arguments + /// * `engine` - The engine as a string. + /// + /// # Returns + /// A new `Engine` struct. + pub fn from_string(engine: String) -> Self { + match engine.as_str() { + "native" => Engine::Native, + "pytorch" => Engine::PyTorch, + _ => Engine::Undefined, + } + } + + /// Translates the struct to a string. + /// + /// # Returns + /// * `String` - The struct as a string. + pub fn to_string(&self) -> String { + match self { + Engine::Native => "native".to_string(), + Engine::PyTorch => "pytorch".to_string(), + Engine::Undefined => "".to_string(), + } + } + +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/input_dims.rs b/modules/c-wrapper/build-context/core/src/storage/header/input_dims.rs new file mode 100644 index 0000000..4f1b87d --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/input_dims.rs @@ -0,0 +1,82 @@ +//! 
InputDims is a struct that holds the dimensions of the input tensors for the model. + + +/// InputDims is a struct that holds the dimensions of the input tensors for the model. +/// +/// # Fields +/// * `dims` - The dimensions of the input tensors. +#[derive(Debug, PartialEq)] +pub struct InputDims { + pub dims: [i32; 2], +} + + +impl InputDims { + + /// Creates a new `InputDims` struct with all zeros. + /// + /// # Returns + /// A new `InputDims` struct with all zeros. + pub fn fresh() -> Self { + InputDims { + dims: [0, 0], + } + } + + /// Creates a new `InputDims` struct from a string. + /// + /// # Arguments + /// * `data` - The dimensions as a string. + /// + /// # Returns + /// A new `InputDims` struct. + pub fn from_string(data: String) -> InputDims { + if data == "".to_string() { + return InputDims::fresh(); + } + let dims: Vec<&str> = data.split(",").collect(); + let dims: Vec = dims.iter().map(|x| x.parse::().unwrap()).collect(); + InputDims { + dims: [dims[0], dims[1]], + } + } + + /// Translates the struct to a string. + /// + /// # Returns + /// * `String` - The struct as a string. 
+ pub fn to_string(&self) -> String { + if self.dims[0] == 0 && self.dims[1] == 0 { + return "".to_string(); + } + format!("{},{}", self.dims[0], self.dims[1]) + } +} + + +#[cfg(test)] +pub mod tests { + + use super::*; + + #[test] + fn test_fresh() { + let input_dims = InputDims::fresh(); + assert_eq!(input_dims.dims[0], 0); + assert_eq!(input_dims.dims[1], 0); + } + + #[test] + fn test_from_string() { + let input_dims = InputDims::from_string("1,2".to_string()); + assert_eq!(input_dims.dims[0], 1); + assert_eq!(input_dims.dims[1], 2); + } + + #[test] + fn test_to_string() { + let input_dims = InputDims::from_string("1,2".to_string()); + assert_eq!(input_dims.to_string(), "1,2".to_string()); + } + +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/keys.rs b/modules/c-wrapper/build-context/core/src/storage/header/keys.rs new file mode 100644 index 0000000..200c1f6 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/keys.rs @@ -0,0 +1,211 @@ +//! Defines the key bindings for input data. +use std::collections::HashMap; + +use crate::safe_eject_internal; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; + + +/// Defines the key bindings for input data. +/// +/// # Fields +/// * `store` - A vector of strings that represent the column names. The order of this store is the same as the order +/// in which the columns are expected in the input data. +/// * `reference` - A hashmap that maps the column names to their index in the `self.store` field. +#[derive(Debug, PartialEq)] +pub struct KeyBindings { + pub store: Vec, + pub reference: HashMap, +} + + +impl KeyBindings { + + /// Creates a new key bindings with no columns. + /// + /// # Returns + /// A new key bindings with no columns. + pub fn fresh() -> Self { + KeyBindings { + store: Vec::new(), + reference: HashMap::new(), + } + } + + /// Adds a column name to the `self.store` field. 
It must be noted that the order in which the columns are added is + /// the order in which they will be expected in the input data. + /// + /// # Arguments + /// * `column_name` - The name of the column to be added. + pub fn add_column(&mut self, column_name: String) { + let index = self.store.len(); + self.store.push(column_name.clone()); + self.reference.insert(column_name, index); + } + + /// Constructs the key bindings from a string. + /// + /// # Arguments + /// * `data` - The string to be converted into key bindings. + /// + /// # Returns + /// The key bindings constructed from the string. + pub fn from_string(data: String) -> Self { + if data.len() == 0 { + return KeyBindings::fresh() + } + let mut store = Vec::new(); + let mut reference = HashMap::new(); + + let lines = data.split("=>"); + let mut count = 0; + + for line in lines { + store.push(line.to_string()); + reference.insert(line.to_string(), count); + count += 1; + } + KeyBindings { + store, + reference, + } + } + + /// converts the key bindings to a string. + /// + /// # Returns + /// The key bindings as a string. + pub fn to_string(&self) -> String { + self.store.join("=>") + } + + /// Constructs the key bindings from bytes. + /// + /// # Arguments + /// * `data` - The bytes to be converted into key bindings. + /// + /// # Returns + /// The key bindings constructed from the bytes. + pub fn from_bytes(data: &[u8]) -> Result { + let data = safe_eject_internal!(String::from_utf8(data.to_vec())); + Ok(Self::from_string(data)) + } + + /// Converts the key bindings to bytes. + /// + /// # Returns + /// The key bindings as bytes. 
+ pub fn to_bytes(&self) -> Vec { + self.to_string().into_bytes() + } + + + +} + + +#[cfg(test)] +pub mod tests { + + use super::*; + + pub fn generate_string() -> String { + "a=>b=>c=>d=>e=>f".to_string() + } + + pub fn generate_bytes() -> Vec { + "a=>b=>c=>d=>e=>f".to_string().into_bytes() + } + + fn generate_struct() -> KeyBindings { + let store = vec!["a".to_string(), "b".to_string(), "c".to_string(), "d".to_string(), "e".to_string(), "f".to_string()]; + let mut reference = HashMap::new(); + reference.insert("a".to_string(), 0); + reference.insert("b".to_string(), 1); + reference.insert("c".to_string(), 2); + reference.insert("d".to_string(), 3); + reference.insert("e".to_string(), 4); + reference.insert("f".to_string(), 5); + KeyBindings { + store, + reference, + } + } + + #[test] + fn test_from_string_with_empty_string() { + let data = "".to_string(); + let bindings = KeyBindings::from_string(data); + assert_eq!(bindings.store.len(), 0); + assert_eq!(bindings.reference.len(), 0); + } + + #[test] + fn test_from_string() { + let data = generate_string(); + let bindings = KeyBindings::from_string(data); + assert_eq!(bindings.store[0], "a"); + assert_eq!(bindings.store[1], "b"); + assert_eq!(bindings.store[2], "c"); + assert_eq!(bindings.store[3], "d"); + assert_eq!(bindings.store[4], "e"); + assert_eq!(bindings.store[5], "f"); + + assert_eq!(bindings.reference["a"], 0); + assert_eq!(bindings.reference["b"], 1); + assert_eq!(bindings.reference["c"], 2); + assert_eq!(bindings.reference["d"], 3); + assert_eq!(bindings.reference["e"], 4); + assert_eq!(bindings.reference["f"], 5); + } + + #[test] + fn test_to_string() { + let bindings = generate_struct(); + let data = bindings.to_string(); + assert_eq!(data, generate_string()); + } + + #[test] + fn test_from_bytes() { + let data = generate_bytes(); + let bindings = KeyBindings::from_bytes(&data).unwrap(); + assert_eq!(bindings.store[0], "a"); + assert_eq!(bindings.store[1], "b"); + assert_eq!(bindings.store[2], "c"); 
+ assert_eq!(bindings.store[3], "d"); + assert_eq!(bindings.store[4], "e"); + assert_eq!(bindings.store[5], "f"); + + assert_eq!(bindings.reference["a"], 0); + assert_eq!(bindings.reference["b"], 1); + assert_eq!(bindings.reference["c"], 2); + assert_eq!(bindings.reference["d"], 3); + assert_eq!(bindings.reference["e"], 4); + assert_eq!(bindings.reference["f"], 5); + } + + #[test] + fn test_to_bytes() { + let bindings = generate_struct(); + let data = bindings.to_bytes(); + assert_eq!(data, generate_bytes()); + } + + #[test] + fn test_add_column() { + let mut bindings = generate_struct(); + bindings.add_column("g".to_string()); + assert_eq!(bindings.store[6], "g"); + assert_eq!(bindings.reference["g"], 6); + + let mut bindings = KeyBindings::fresh(); + bindings.add_column("a".to_string()); + bindings.add_column("b".to_string()); + + assert_eq!(bindings.store[0], "a"); + assert_eq!(bindings.reference["a"], 0); + assert_eq!(bindings.store[1], "b"); + assert_eq!(bindings.reference["b"], 1); + } + +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/mod.rs b/modules/c-wrapper/build-context/core/src/storage/header/mod.rs new file mode 100644 index 0000000..7c19f34 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/mod.rs @@ -0,0 +1,368 @@ +//! Handles the loading, saving, and utilisation of all the data in the header of the model file. +pub mod keys; +pub mod normalisers; +pub mod output; +pub mod string_value; +pub mod version; +pub mod engine; +pub mod origin; +pub mod input_dims; + +use keys::KeyBindings; +use normalisers::wrapper::NormaliserType; +use normalisers::NormaliserMap; +use output::Output; +use string_value::StringValue; +use version::Version; +use engine::Engine; +use origin::Origin; +use input_dims::InputDims; +use crate::safe_eject; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; + + +/// The header of the model file. 
+/// +/// # Fields +/// * `keys` - The key bindings where the order of the input columns is stored. +/// * `normalisers` - The normalisers where the normalisation functions are stored per column if there are any. +/// * `output` - The output where the output column name and normaliser are stored if there are any. +/// * `name` - The name of the model. +/// * `version` - The version of the model. +/// * `description` - The description of the model. +/// * `engine` - The engine of the model (could be native or pytorch). +/// * `origin` - The origin of the model which is where the model was created and who the author is. +#[derive(Debug, PartialEq)] +pub struct Header { + pub keys: KeyBindings, + pub normalisers: NormaliserMap, + pub output: Output, + pub name: StringValue, + pub version: Version, + pub description: StringValue, + pub engine: Engine, + pub origin: Origin, + pub input_dims: InputDims, +} + + +impl Header { + + /// Creates a new header with no columns or normalisers. + /// + /// # Returns + /// A new header with no columns or normalisers. + pub fn fresh() -> Self { + Header { + keys: KeyBindings::fresh(), + normalisers: NormaliserMap::fresh(), + output: Output::fresh(), + name: StringValue::fresh(), + version: Version::fresh(), + description: StringValue::fresh(), + engine: Engine::fresh(), + origin: Origin::fresh(), + input_dims: InputDims::fresh(), + } + } + + /// Adds a model name to the `self.name` field. + /// + /// # Arguments + /// * `model_name` - The name of the model to be added. + pub fn add_name(&mut self, model_name: String) { + self.name = StringValue::from_string(model_name); + } + + /// Adds a version to the `self.version` field. + /// + /// # Arguments + /// * `version` - The version to be added. + pub fn add_version(&mut self, version: String) -> Result<(), SurrealError> { + self.version = Version::from_string(version)?; + Ok(()) + } + + /// Adds a description to the `self.description` field. 
+ /// + /// # Arguments + /// * `description` - The description to be added. + pub fn add_description(&mut self, description: String) { + self.description = StringValue::from_string(description); + } + + /// Adds a column name to the `self.keys` field. It must be noted that the order in which the columns are added is + /// the order in which they will be expected in the input data. We can do this with the followng example: + /// + /// # Arguments + /// * `column_name` - The name of the column to be added. + pub fn add_column(&mut self, column_name: String) { + self.keys.add_column(column_name); + } + + /// Adds a normaliser to the `self.normalisers` field. + /// + /// # Arguments + /// * `column_name` - The name of the column to which the normaliser will be applied. + /// * `normaliser` - The normaliser to be applied to the column. + pub fn add_normaliser(&mut self, column_name: String, normaliser: NormaliserType) -> Result<(), SurrealError> { + let _ = self.normalisers.add_normaliser(normaliser, column_name, &self.keys)?; + Ok(()) + } + + /// Gets the normaliser for a given column name. + /// + /// # Arguments + /// * `column_name` - The name of the column to which the normaliser will be applied. + /// + /// # Returns + /// The normaliser for the given column name. + pub fn get_normaliser(&self, column_name: &String) -> Result, SurrealError> { + self.normalisers.get_normaliser(column_name.to_string(), &self.keys) + } + + /// Adds an output column to the `self.output` field. + /// + /// # Arguments + /// * `column_name` - The name of the column to be added. + /// * `normaliser` - The normaliser to be applied to the column. + pub fn add_output(&mut self, column_name: String, normaliser: Option) { + self.output.name = Some(column_name); + self.output.normaliser = normaliser; + } + + /// Adds an engine to the `self.engine` field. + /// + /// # Arguments + /// * `engine` - The engine to be added. 
+ pub fn add_engine(&mut self, engine: String) { + self.engine = Engine::from_string(engine); + } + + /// Adds an author to the `self.origin` field. + /// + /// # Arguments + /// * `author` - The author to be added. + pub fn add_author(&mut self, author: String) { + self.origin.add_author(author); + } + + /// Adds an origin to the `self.origin` field. + /// + /// # Arguments + /// * `origin` - The origin to be added. + pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> { + self.origin.add_origin(origin) + } + + /// The standard delimiter used to seperate each field in the header. + fn delimiter() -> &'static str { + "//=>" + } + + /// Constructs the `Header` struct from bytes. + /// + /// # Arguments + /// * `data` - The bytes to be converted into a `Header` struct. + /// + /// # Returns + /// The `Header` struct. + pub fn from_bytes(data: Vec) -> Result { + + let string_data = safe_eject!(String::from_utf8(data), SurrealErrorStatus::BadRequest); + + let buffer = string_data.split(Self::delimiter()).collect::>(); + + let keys: KeyBindings = KeyBindings::from_string(buffer.get(1).unwrap_or(&"").to_string()); + let normalisers = NormaliserMap::from_string(buffer.get(2).unwrap_or(&"").to_string(), &keys)?; + let output = Output::from_string(buffer.get(3).unwrap_or(&"").to_string())?; + let name = StringValue::from_string(buffer.get(4).unwrap_or(&"").to_string()); + let version = Version::from_string(buffer.get(5).unwrap_or(&"").to_string())?; + let description = StringValue::from_string(buffer.get(6).unwrap_or(&"").to_string()); + let engine = Engine::from_string(buffer.get(7).unwrap_or(&"").to_string()); + let origin = Origin::from_string(buffer.get(8).unwrap_or(&"").to_string())?; + let input_dims = InputDims::from_string(buffer.get(9).unwrap_or(&"").to_string()); + Ok(Header {keys, normalisers, output, name, version, description, engine, origin, input_dims}) + } + + /// Converts the `Header` struct into bytes. 
+ /// + /// # Returns + /// A tuple containing the number of bytes in the header and the bytes themselves. + pub fn to_bytes(&self) -> (i32, Vec) { + let buffer = vec![ + "".to_string(), + self.keys.to_string(), + self.normalisers.to_string(), + self.output.to_string(), + self.name.to_string(), + self.version.to_string(), + self.description.to_string(), + self.engine.to_string(), + self.origin.to_string(), + self.input_dims.to_string(), + "".to_string(), + ]; + let buffer = buffer.join(Self::delimiter()).into_bytes(); + (buffer.len() as i32, buffer) + } +} + + +#[cfg(test)] +mod tests { + + use super::*; + use super::keys::tests::generate_string as generate_key_string; + use super::normalisers::tests::generate_string as generate_normaliser_string; + use super::normalisers::{ + clipping::Clipping, + linear_scaling::LinearScaling, + log_scale::LogScaling, + z_score::ZScore, + }; + + + pub fn generate_string() -> String { + let keys = generate_key_string(); + let normalisers = generate_normaliser_string(); + let output = "g=>linear_scaling(0.0,1.0)".to_string(); + format!( + "{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}{}", + Header::delimiter(), + keys, + Header::delimiter(), + normalisers, + Header::delimiter(), + output, + Header::delimiter(), + "test model name".to_string(), + Header::delimiter(), + "0.0.1".to_string(), + Header::delimiter(), + "test description".to_string(), + Header::delimiter(), + Engine::PyTorch.to_string(), + Header::delimiter(), + Origin::from_string("author=>local".to_string()).unwrap().to_string(), + Header::delimiter(), + InputDims::from_string("1,2".to_string()).to_string(), + Header::delimiter(), + ) + } + + pub fn generate_bytes() -> Vec { + generate_string().into_bytes() + } + + #[test] + fn test_from_bytes() { + let header = Header::from_bytes(generate_bytes()).unwrap(); + + assert_eq!(header.keys.store.len(), 6); + assert_eq!(header.keys.reference.len(), 6); + assert_eq!(header.normalisers.store.len(), 4); + + 
assert_eq!(header.keys.store[0], "a"); + assert_eq!(header.keys.store[1], "b"); + assert_eq!(header.keys.store[2], "c"); + assert_eq!(header.keys.store[3], "d"); + assert_eq!(header.keys.store[4], "e"); + assert_eq!(header.keys.store[5], "f"); + } + + #[test] + fn test_empty_header() { + let string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string(); + let data = string.as_bytes(); + let header = Header::from_bytes(data.to_vec()).unwrap(); + + assert_eq!(header, Header::fresh()); + + let string = "".to_string(); + let data = string.as_bytes(); + let header = Header::from_bytes(data.to_vec()).unwrap(); + + assert_eq!(header, Header::fresh()); + } + + #[test] + fn test_to_bytes() { + let header = Header::from_bytes(generate_bytes()).unwrap(); + let (bytes_num, bytes) = header.to_bytes(); + let string = String::from_utf8(bytes).unwrap(); + + // below the integers are correct but there is a difference with the decimal point representation in the string, we can alter this + // fairly easy and will investigate it + let expected_string = "//=>a=>b=>c=>d=>e=>f//=>a=>linear_scaling(0,1)//b=>clipping(0,1.5)//c=>log_scaling(10,0)//e=>z_score(0,1)//=>g=>linear_scaling(0,1)//=>test model name//=>0.0.1//=>test description//=>pytorch//=>author=>local//=>1,2//=>".to_string(); + + assert_eq!(string, expected_string); + assert_eq!(bytes_num, expected_string.len() as i32); + + let empty_header = Header::fresh(); + let (bytes_num, bytes) = empty_header.to_bytes(); + let string = String::from_utf8(bytes).unwrap(); + let expected_string = "//=>//=>//=>//=>//=>//=>//=>//=>//=>//=>".to_string(); + + assert_eq!(string, expected_string); + assert_eq!(bytes_num, expected_string.len() as i32); + } + + #[test] + fn test_add_column() { + let mut header = Header::fresh(); + header.add_column("a".to_string()); + header.add_column("b".to_string()); + header.add_column("c".to_string()); + header.add_column("d".to_string()); + header.add_column("e".to_string()); + 
header.add_column("f".to_string()); + + assert_eq!(header.keys.store.len(), 6); + assert_eq!(header.keys.reference.len(), 6); + + assert_eq!(header.keys.store[0], "a"); + assert_eq!(header.keys.store[1], "b"); + assert_eq!(header.keys.store[2], "c"); + assert_eq!(header.keys.store[3], "d"); + assert_eq!(header.keys.store[4], "e"); + assert_eq!(header.keys.store[5], "f"); + } + + #[test] + fn test_add_normalizer() { + let mut header = Header::fresh(); + header.add_column("a".to_string()); + header.add_column("b".to_string()); + header.add_column("c".to_string()); + header.add_column("d".to_string()); + header.add_column("e".to_string()); + header.add_column("f".to_string()); + + let _ = header.add_normaliser( + "a".to_string(), + NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 }) + ); + let _ = header.add_normaliser( + "b".to_string(), + NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) }) + ); + let _ = header.add_normaliser( + "c".to_string(), + NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 }) + ); + let _ = header.add_normaliser( + "e".to_string(), + NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 }) + ); + + assert_eq!(header.normalisers.store.len(), 4); + assert_eq!(header.normalisers.store[0], NormaliserType::LinearScaling(LinearScaling { min: 0.0, max: 1.0 })); + assert_eq!(header.normalisers.store[1], NormaliserType::Clipping(Clipping { min: Some(0.0), max: Some(1.5) })); + assert_eq!(header.normalisers.store[2], NormaliserType::LogScaling(LogScaling { base: 10.0, min: 0.0 })); + assert_eq!(header.normalisers.store[3], NormaliserType::ZScore(ZScore { mean: 0.0, std_dev: 1.0 })); + } + +} + + diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/clipping.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/clipping.rs new file mode 100644 index 0000000..e93dc30 --- /dev/null +++ 
b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/clipping.rs @@ -0,0 +1,118 @@ +//! The functionality and parameters around a clipping normaliser. +use super::traits::Normaliser; + + +/// A clipping normaliser. +/// +/// # Fields +/// * `min` - The minimum value to clip to. +/// * `max` - The maximum value to clip to. +#[derive(Debug, PartialEq)] +pub struct Clipping { + pub min: Option, + pub max: Option, +} + + +impl Normaliser for Clipping { + + /// Normalises a value. + /// + /// # Arguments + /// * `input` - The value to normalise. + /// + /// # Returns + /// The normalised value. + fn normalise(&self, input: f32)-> f32 { + let normalised = match (self.min, self.max) { + (Some(min), Some(max)) => { + if input < min { + min + } else if input > max { + max + } else { + input + } + }, + (Some(min), None) => { + if input < min { + min + } else { + input + } + }, + (None, Some(max)) => { + if input > max { + max + } else { + input + } + }, + (None, None) => { + input + }, + }; + normalised + } + + fn key() -> String { + "clipping".to_string() + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_normalise_with_both_bounds() { + let normaliser = Clipping { + min: Some(0.0), + max: Some(1.0), + }; + let input = 0.5; + let expected = 0.5; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + + #[test] + fn test_normalise_with_min_bound() { + let normaliser = Clipping { + min: Some(0.0), + max: None, + }; + let input = -0.5; + let expected = 0.0; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + + #[test] + fn test_normalise_with_max_bound() { + let normaliser = Clipping { + min: None, + max: Some(1.0), + }; + let input = 1.5; + let expected = 1.0; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + + #[test] + fn test_normalise_with_no_bounds() { + let normaliser = Clipping { + min: None, + max: None, + }; + let input = 0.5; + let 
expected = 0.5; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/linear_scaling.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/linear_scaling.rs new file mode 100644 index 0000000..b424d4a --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/linear_scaling.rs @@ -0,0 +1,74 @@ +//! The functionality and parameters around a linear scaling normaliser. +use super::traits::Normaliser; + + +/// A linear scaling normaliser. +/// +/// # Fields +/// * `min` - The minimum value to scale to. +/// * `max` - The maximum value to scale to. +#[derive(Debug, PartialEq)] +pub struct LinearScaling { + pub min: f32, + pub max: f32, +} + + +impl Normaliser for LinearScaling { + + /// Normalises a value. + /// + /// # Arguments + /// * `input` - The value to normalise. + /// + /// # Returns + /// The normalised value. + fn normalise(&self, input: f32)-> f32 { + let range = self.max - self.min; + let normalised = (input - self.min) / range; + normalised + } + + /// Applies the inverse of the value for the normaliser. + /// + /// # Arguments + /// * `input` - The value to inverse normalise. + /// + /// # Returns + /// The inverse normalised value. + fn inverse_normalise(&self, input: f32) -> f32 { + let range = self.max - self.min; + let denormalised = (input * range) + self.min; + denormalised + } + + + /// The key of the normaliser. + /// + /// # Returns + /// The key of the normaliser. 
+ fn key() -> String { + "linear_scaling".to_string() + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_normalise_with_both_bounds() { + let normaliser = LinearScaling { + min: 0.0, + max: 100.0, + }; + let input = 50.0; + let expected = 0.5; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/log_scale.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/log_scale.rs new file mode 100644 index 0000000..f3dad7a --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/log_scale.rs @@ -0,0 +1,71 @@ +//! The functionality and parameters around a log scaling normaliser. +use super::traits::Normaliser; + + +/// A log scaling normaliser. +/// +/// # Fields +/// * `base` - The base of the logarithm. +/// * `min` - The minimum value to scale to. +#[derive(Debug, PartialEq)] +pub struct LogScaling { + pub base: f32, + pub min: f32, +} + + +impl Normaliser for LogScaling { + + /// Normalises a value. + /// + /// # Arguments + /// * `input` - The value to normalise. + /// + /// # Returns + /// The normalised value. + fn normalise(&self, input: f32)-> f32 { + let normalised = (input + self.min).log(self.base); + normalised + } + + /// Applies the inverse of the value for the normaliser. + /// + /// # Arguments + /// * `input` - The value to inverse normalise. + /// + /// # Returns + /// The inverse normalised value. + fn inverse_normalise(&self, input: f32) -> f32 { + let denormalised = (input.powf(self.base)) - self.min; + denormalised + } + + /// The key of the normaliser. + /// + /// # Returns + /// The key of the normaliser. 
+ fn key() -> String { + "log_scaling".to_string() + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_normalise_with_both_bounds() { + let normaliser = LogScaling { + base: 10.0, + min: 0.0, + }; + let input = 10.0; + let expected = 1.0; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/mod.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/mod.rs new file mode 100644 index 0000000..5b1acbb --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/mod.rs @@ -0,0 +1,235 @@ +//! Defines the loading and saving functionality of normalisers. +use std::collections::HashMap; + +pub mod traits; +pub mod utils; +pub mod linear_scaling; +pub mod clipping; +pub mod log_scale; +pub mod z_score; +pub mod wrapper; + +use super::keys::KeyBindings; +use utils::{extract_label, extract_two_numbers}; +use wrapper::NormaliserType; +use crate::safe_eject_option; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; + + +/// A map of normalisers so they can be accessed by column name and input index. +/// +/// # Fields +/// * `store` - A vector of normalisers. +/// * `store_ref` - A vector of column names to correlate with the normalisers in the store. +/// * `reference` - A map of the index of the column in the key bindings to the index of the normaliser in the store. +#[derive(Debug, PartialEq)] +pub struct NormaliserMap { + pub store: Vec, + pub store_ref: Vec, + pub reference: HashMap, +} + +impl NormaliserMap { + + /// Constructs a new, empty `NormaliserMap`. + /// + /// # Returns + /// A new, empty `NormaliserMap`. + pub fn fresh() -> Self { + NormaliserMap { + store: Vec::new(), + store_ref: Vec::new(), + reference: HashMap::new(), + } + } + + /// Adds a normaliser to the map. + /// + /// # Arguments + /// * `normaliser` - The normaliser to add. 
+ /// * `column_name` - The name of the column to which the normaliser is applied. + /// * `keys_reference` - A reference to the key bindings to extract the index. + pub fn add_normaliser(&mut self, normaliser: NormaliserType, column_name: String, keys_reference: &KeyBindings) -> Result<(), SurrealError> { + let counter = self.store.len(); + let column_input_index = safe_eject_option!(keys_reference.reference.get(column_name.as_str())); + self.reference.insert(column_input_index.clone() as usize, counter as usize); + self.store.push(normaliser); + self.store_ref.push(column_name); + Ok(()) + } + + /// Gets a normaliser from the map. + /// + /// # Arguments + /// * `column_name` - The name of the column to which the normaliser is applied. + /// * `keys_reference` - A reference to the key bindings to extract the index. + /// + /// # Returns + /// The normaliser corresponding to the column name. + pub fn get_normaliser(&self, column_name: String, keys_reference: &KeyBindings) -> Result, SurrealError> { + let column_input_index = safe_eject_option!(keys_reference.reference.get(column_name.as_str())); + let normaliser_index = self.reference.get(column_input_index); + match normaliser_index { + Some(normaliser_index) => Ok(Some(&self.store[*normaliser_index])), + None => Ok(None), + } + } + + /// unpacks the normaliser data from a string. + /// + /// # Arguments + /// * `normaliser_data` - The string containing the normaliser data. + /// + /// # Returns + /// A tuple containing the label (type of normaliser), the numbers and the column name. 
+ pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), SurrealError> { + let mut normaliser_buffer = normaliser_data.split("=>"); + + let column_name = safe_eject_option!(normaliser_buffer.next()); + let normaliser_type = safe_eject_option!(normaliser_buffer.next()).to_string(); + + let label = extract_label(&normaliser_type)?; + let numbers = extract_two_numbers(&normaliser_type)?; + Ok((label, numbers, column_name.to_string())) + } + + /// Constructs a `NormaliserMap` from a string. + /// + /// # Arguments + /// * `data` - The string containing the normaliser data. + /// * `keys_reference` - A reference to the key bindings to extract the index. + /// + /// # Returns + /// A `NormaliserMap` containing the normalisers. + pub fn from_string(data: String, keys_reference: &KeyBindings) -> Result { + if data.len() == 0 { + return Ok(NormaliserMap::fresh()) + } + let normalisers_data = data.split("//"); + let mut counter = 0; + let mut reference = HashMap::new(); + let mut store = Vec::new(); + let mut store_ref = Vec::new(); + + for normaliser_data in normalisers_data { + let (normaliser, column_name) = NormaliserType::from_string(normaliser_data.to_string())?; + let column_input_index = safe_eject_option!(keys_reference.reference.get(column_name.as_str())); + reference.insert(column_input_index.clone() as usize, counter as usize); + store.push(normaliser); + store_ref.push(column_name); + counter += 1; + } + + Ok(NormaliserMap { + reference, + store, + store_ref + }) + } + + /// Converts the `NormaliserMap` to a string. + /// + /// # Returns + /// A string containing the normaliser data. 
+ pub fn to_string(&self) -> String { + let mut buffer = Vec::new(); + + for index in 0..self.store.len() { + let normaliser_string = &self.store[index].to_string(); + buffer.push(format!("{}=>{}", self.store_ref[index], normaliser_string)); + } + + buffer.join("//") + } +} + + +#[cfg(test)] +pub mod tests { + + use super::*; + use super::super::keys::tests::generate_string as generate_key_bindings_string; + use super::super::keys::KeyBindings; + + pub fn generate_string() -> String { + "a=>linear_scaling(0.0,1.0)//b=>clipping(0.0,1.5)//c=>log_scaling(10.0,0.0)//e=>z_score(0.0,1.0)".to_string() + } + + pub fn generate_key_bindings() -> KeyBindings { + let data = generate_key_bindings_string(); + KeyBindings::from_string(data) + } + + #[test] + pub fn test_from_string() { + + let key_bindings = generate_key_bindings(); + + let data = generate_string(); + + let normaliser_map = NormaliserMap::from_string(data, &key_bindings).unwrap(); + + assert_eq!(normaliser_map.reference.len(), 4); + assert_eq!(normaliser_map.store.len(), 4); + + assert_eq!(normaliser_map.reference.get(&0).unwrap(), &0); + assert_eq!(normaliser_map.reference.get(&1).unwrap(), &1); + assert_eq!(normaliser_map.reference.get(&2).unwrap(), &2); + assert_eq!(normaliser_map.reference.get(&4).unwrap(), &3); + } + + #[test] + fn test_to_string() { + let key_bindings = generate_key_bindings(); + let data = generate_string(); + let normaliser_map = NormaliserMap::from_string(data, &key_bindings).unwrap(); + let normaliser_map_string = normaliser_map.to_string(); + + assert_eq!(normaliser_map_string, "a=>linear_scaling(0,1)//b=>clipping(0,1.5)//c=>log_scaling(10,0)//e=>z_score(0,1)"); + } + + #[test] + fn test_add_normalizer() { + + let key_bindings = generate_key_bindings(); + let data = generate_string(); + + let mut normaliser_map = NormaliserMap::from_string(data, &key_bindings).unwrap(); + + let _ = normaliser_map.add_normaliser(NormaliserType::LinearScaling(linear_scaling::LinearScaling{min: 0.0, max: 
1.0}), "d".to_string(), &key_bindings); + + assert_eq!(normaliser_map.reference.len(), 5); + assert_eq!(normaliser_map.store.len(), 5); + + assert_eq!(normaliser_map.reference.get(&0).unwrap(), &0); + assert_eq!(normaliser_map.reference.get(&1).unwrap(), &1); + assert_eq!(normaliser_map.reference.get(&2).unwrap(), &2); + assert_eq!(normaliser_map.reference.get(&4).unwrap(), &3); + assert_eq!(normaliser_map.reference.get(&3).unwrap(), &4); + + assert_eq!(normaliser_map.store_ref[0], "a"); + assert_eq!(normaliser_map.store_ref[1], "b"); + assert_eq!(normaliser_map.store_ref[2], "c"); + assert_eq!(normaliser_map.store_ref[3], "e"); + assert_eq!(normaliser_map.store_ref[4], "d"); + } + + #[test] + fn test_get_normaliser() { + let key_bindings = generate_key_bindings(); + let data = generate_string(); + + let normaliser_map = NormaliserMap::from_string(data, &key_bindings).unwrap(); + + let normaliser = normaliser_map.get_normaliser("e".to_string(), &key_bindings).unwrap().unwrap(); + + match normaliser { + NormaliserType::ZScore(z_score) => { + assert_eq!(z_score.mean, 0.0); + assert_eq!(z_score.std_dev, 1.0); + }, + _ => panic!("Wrong normaliser type") + } + } +} + diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/traits.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/traits.rs new file mode 100644 index 0000000..6fa9030 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/traits.rs @@ -0,0 +1,14 @@ +//! traits for the normalisers module. + +pub trait Normaliser { + + /// Normalises a value. + fn normalise(&self, input: f32)-> f32; + + fn inverse_normalise(&self, input: f32)-> f32 { + input + } + + /// Returns the key of the normaliser. 
+ fn key() -> String; +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/utils.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/utils.rs new file mode 100644 index 0000000..21d2b58 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/utils.rs @@ -0,0 +1,66 @@ +//! Utility functions for normalisers to reduce code duplication in areas that cannot be easily defined in a struct. +use regex::{Regex, Captures}; +use crate::{ + safe_eject_option, + safe_eject_internal, +}; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; + + +/// Extracts the label from a normaliser string. +/// +/// # Arguments +/// * `data` - A string containing the normaliser data. +pub fn extract_label(data: &String) -> Result { + let re: Regex = safe_eject_internal!(Regex::new(r"^(.*?)\(")); + let captures: Captures = safe_eject_option!(re.captures(data)); + Ok(safe_eject_option!(captures.get(1)).as_str().to_string()) +} + + +/// Extracts two numbers from a string with brackets where the numbers are in the brackets seperated by comma. +/// +/// # Arguments +/// * `data` - A string containing the normaliser data. 
+/// +/// # Returns +/// [number1, number2] from `"(number1, number2)"` +pub fn extract_two_numbers(data: &String) -> Result<[f32; 2], SurrealError> { + let re: Regex = safe_eject_internal!(Regex::new(r"[-+]?\d+(\.\d+)?")); + let mut numbers = re.find_iter(data); + let mut buffer: [f32; 2] = [0.0, 0.0]; + + let num_one_str = safe_eject_option!(numbers.next()).as_str(); + let num_two_str = safe_eject_option!(numbers.next()).as_str(); + + buffer[0] = safe_eject_internal!(num_one_str.parse::()); + buffer[1] = safe_eject_internal!(num_two_str.parse::()); + Ok(buffer) +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_extract_two_numbers() { + let data = "linear_scaling(0.0,1.0)".to_string(); + let numbers = extract_two_numbers(&data).unwrap(); + assert_eq!(numbers[0], 0.0); + assert_eq!(numbers[1], 1.0); + + let data = "linear_scaling(0,1)".to_string(); + let numbers = extract_two_numbers(&data).unwrap(); + assert_eq!(numbers[0], 0.0); + assert_eq!(numbers[1], 1.0); + } + + #[test] + fn test_extract_label() { + let data = "linear_scaling(0.0,1.0)".to_string(); + let label = extract_label(&data).unwrap(); + assert_eq!(label, "linear_scaling"); + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/wrapper.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/wrapper.rs new file mode 100644 index 0000000..27a2174 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/wrapper.rs @@ -0,0 +1,199 @@ +//! Defines the constructing and storing of normalisers. +use super::linear_scaling; +use super::clipping; +use super::log_scale; +use super::z_score; +use super::utils::{extract_label, extract_two_numbers}; +use super::traits::Normaliser; + +use crate::safe_eject_option; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; + + +/// A wrapper for all different types of normalisers. 
+/// +/// # Arguments +/// * `LinearScaling` - A linear scaling normaliser. +/// * `Clipping` - A clipping normaliser. +/// * `LogScaling` - A log scaling normaliser. +/// * `ZScore` - A z-score normaliser. +#[derive(Debug, PartialEq)] +pub enum NormaliserType { + LinearScaling(linear_scaling::LinearScaling), + Clipping(clipping::Clipping), + LogScaling(log_scale::LogScaling), + ZScore(z_score::ZScore), +} + + +impl NormaliserType { + + /// Constructs a new normaliser. + /// + /// # Arguments + /// * `label` - The label of the normaliser. + /// * `one` - The first parameter of the normaliser. + /// * `two` - The second parameter of the normaliser. + /// + /// # Returns + /// A new normaliser. + pub fn new(label: String, one: f32, two: f32) -> Self { + match label.as_str() { + "linear_scaling" => NormaliserType::LinearScaling(linear_scaling::LinearScaling{min: one, max: two}), + "clipping" => NormaliserType::Clipping(clipping::Clipping{min: Some(one), max: Some(two)}), + "log_scaling" => NormaliserType::LogScaling(log_scale::LogScaling{base: one, min: two}), + "z_score" => NormaliserType::ZScore(z_score::ZScore{mean: one, std_dev: two}), + _ => panic!("Invalid normaliser label: {}", label), + } + } + + /// Unpacks a normaliser from a string. + /// + /// # Arguments + /// * `normaliser_data` - A string containing the normaliser data. + /// + /// # Returns + /// (type of normaliser, [normaliser parameters], column name) + pub fn unpack_normaliser_data(normaliser_data: &str) -> Result<(String, [f32; 2], String), SurrealError> { + let mut normaliser_buffer = normaliser_data.split("=>"); + + let column_name = safe_eject_option!(normaliser_buffer.next()); + let normaliser_type = safe_eject_option!(normaliser_buffer.next()).to_string(); + + let label = extract_label(&normaliser_type)?; + let numbers = extract_two_numbers(&normaliser_type)?; + Ok((label, numbers, column_name.to_string())) + } + + /// Constructs a normaliser from a string. 
+ /// + /// # Arguments + /// * `data` - A string containing the normaliser data. + /// + /// # Returns + /// (normaliser, column name) + pub fn from_string(data: String) -> Result<(Self, String), SurrealError> { + let (label, numbers, column_name) = Self::unpack_normaliser_data(&data)?; + let normaliser = match label.as_str() { + "linear_scaling" => { + let min = numbers[0]; + let max = numbers[1]; + NormaliserType::LinearScaling(linear_scaling::LinearScaling{min, max}) + }, + "clipping" => { + let min = numbers[0]; + let max = numbers[1]; + NormaliserType::Clipping(clipping::Clipping{min: Some(min), max: Some(max)}) + }, + "log_scaling" => { + let base = numbers[0]; + let min = numbers[1]; + NormaliserType::LogScaling(log_scale::LogScaling{base, min}) + }, + "z_score" => { + let mean = numbers[0]; + let std_dev = numbers[1]; + NormaliserType::ZScore(z_score::ZScore{mean, std_dev}) + }, + _ => { + let error = SurrealError::new( + format!("Unknown normaliser type: {}", label).to_string(), + SurrealErrorStatus::Unknown + ); + return Err(error) + } + }; + Ok((normaliser, column_name)) + } + + /// Converts a normaliser to a string. + /// + /// # Returns + /// A string containing the normaliser data. + pub fn to_string(&self) -> String { + let normaliser_string = match self { + NormaliserType::LinearScaling(linear_scaling) => { + let min = linear_scaling.min; + let max = linear_scaling.max; + format!("linear_scaling({},{})", min, max) + }, + NormaliserType::Clipping(clipping) => { + let min = clipping.min.unwrap(); + let max = clipping.max.unwrap(); + format!("clipping({},{})", min, max) + }, + NormaliserType::LogScaling(log_scaling) => { + let base = log_scaling.base; + let min = log_scaling.min; + format!("log_scaling({},{})", base, min) + }, + NormaliserType::ZScore(z_score) => { + let mean = z_score.mean; + let std_dev = z_score.std_dev; + format!("z_score({},{})", mean, std_dev) + }, + }; + normaliser_string + } + + /// Normalises a value. 
+ /// + /// # Arguments + /// * `value` - The value to normalise. + /// + /// # Returns + /// The normalised value. + pub fn normalise(&self, value: f32) -> f32 { + match self { + NormaliserType::LinearScaling(normaliser) => normaliser.normalise(value), + NormaliserType::Clipping(normaliser) => normaliser.normalise(value), + NormaliserType::LogScaling(normaliser) => normaliser.normalise(value), + NormaliserType::ZScore(normaliser) => normaliser.normalise(value), + } + } + + /// Inverse normalises a value. + /// + /// # Arguments + /// * `value` - The value to inverse normalise. + /// + /// # Returns + /// The inverse normalised value. + pub fn inverse_normalise(&self, value: f32) -> f32 { + match self { + NormaliserType::LinearScaling(normaliser) => normaliser.inverse_normalise(value), + NormaliserType::Clipping(normaliser) => normaliser.inverse_normalise(value), + NormaliserType::LogScaling(normaliser) => normaliser.inverse_normalise(value), + NormaliserType::ZScore(normaliser) => normaliser.inverse_normalise(value), + } + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + pub fn generate_string() -> String { + let normaliser = NormaliserType::LinearScaling(linear_scaling::LinearScaling{min: 0.0, max: 1.0}); + let column_name = "column_name".to_string(); + format!("{}=>{}", column_name, normaliser.to_string()) + } + + #[test] + fn test_normaliser_type_to_string() { + let normaliser = NormaliserType::LinearScaling(linear_scaling::LinearScaling{min: 0.0, max: 1.0}); + assert_eq!(normaliser.to_string(), "linear_scaling(0,1)"); + } + + #[test] + fn test_normaliser_type_from_string() { + let normaliser_string = generate_string(); + let (normaliser, column_name) = NormaliserType::from_string(normaliser_string).unwrap(); + assert_eq!(normaliser, NormaliserType::LinearScaling(linear_scaling::LinearScaling{min: 0.0, max: 1.0})); + assert_eq!(column_name, "column_name"); + } + +} \ No newline at end of file diff --git 
a/modules/c-wrapper/build-context/core/src/storage/header/normalisers/z_score.rs b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/z_score.rs new file mode 100644 index 0000000..964844f --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/normalisers/z_score.rs @@ -0,0 +1,67 @@ +//! The functionality and parameters around a z-score normaliser. +use super::traits::Normaliser; + + +/// A z-score normaliser. +/// +/// # Fields +/// * `mean` - The mean of the normaliser. +/// * `std_dev` - The standard deviation of the normaliser. +#[derive(Debug, PartialEq)] +pub struct ZScore { + pub mean: f32, + pub std_dev: f32, +} + + +impl Normaliser for ZScore { + + /// Normalises a value. + /// + /// # Arguments + /// * `input` - The value to normalise. + /// + /// # Returns + /// The normalised value. + fn normalise(&self, input: f32)-> f32 { + let normalised = (input - self.mean) / self.std_dev; + normalised + } + + /// Applies the inverse of the value for the normaliser. + /// + /// # Arguments + /// * `input` - The value to inverse normalise. + /// + /// # Returns + /// The inverse normalised value. + fn inverse_normalise(&self, input: f32)-> f32 { + let denormalised = (input * self.std_dev) + self.mean; + denormalised + } + + fn key() -> String { + "z_score".to_string() + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_normalise_with_both_bounds() { + let normaliser = ZScore { + mean: 0.0, + std_dev: 1.0, + }; + let input = 0.0; + let expected = 0.0; + let actual = normaliser.normalise(input); + assert_eq!(expected, actual); + } + +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/origin.rs b/modules/c-wrapper/build-context/core/src/storage/header/origin.rs new file mode 100644 index 0000000..8571f15 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/origin.rs @@ -0,0 +1,181 @@ +//! Defines the origin of the model in the file. 
+use crate::errors::error::{SurrealError, SurrealErrorStatus}; + +use super::string_value::StringValue; + + +const LOCAL: &str = "local"; +const SURREAL_DB: &str = "surreal_db"; +const NONE: &str = ""; + + +/// Defines the types of origin that are supported. +/// +/// # Fields +/// * `Local` - The model was created locally. +/// * `SurrealDb` - The model was created in the surreal database. +/// * `None` - The model has no origin +#[derive(Debug, PartialEq)] +pub enum OriginValue { + Local(StringValue), + SurrealDb(StringValue), + None(StringValue), +} + +impl OriginValue { + + /// Creates a new `OriginValue` with no value. + /// + /// # Returns + /// A new `OriginValue` with no value. + pub fn fresh() -> Self { + OriginValue::None(StringValue::fresh()) + } + + /// Create a `OriginValue` from a string. + /// + /// # Arguments + /// * `origin` - The origin as a string. + /// + /// # Returns + /// A new `OriginValue`. + pub fn from_string(origin: String) -> Result { + match origin.to_lowercase().as_str() { + LOCAL => Ok(OriginValue::Local(StringValue::from_string(origin))), + SURREAL_DB => Ok(OriginValue::SurrealDb(StringValue::from_string(origin))), + NONE => Ok(OriginValue::None(StringValue::from_string(origin))), + _ => Err(SurrealError::new(format!("invalid origin: {}", origin), SurrealErrorStatus::BadRequest)) + } + } + + /// Converts the `OriginValue` to a string. + /// + /// # Returns + /// The `OriginValue` as a string. + pub fn to_string(&self) -> String { + match self { + OriginValue::Local(string_value) => string_value.to_string(), + OriginValue::SurrealDb(string_value) => string_value.to_string(), + OriginValue::None(string_value) => string_value.to_string(), + } + } + +} + + +/// Defines the origin of the model in the file header. +/// +/// # Fields +/// * `origin` - The origin of the model. +/// * `author` - The author of the model. 
+#[derive(Debug, PartialEq)] +pub struct Origin { + pub origin: OriginValue, + pub author: StringValue, +} + + +impl Origin { + + /// Creates a new origin with no values. + /// + /// # Returns + /// A new origin with no values. + pub fn fresh() -> Self { + Origin { + origin: OriginValue::fresh(), + author: StringValue::fresh(), + } + } + + /// Adds an author to the origin struct. + /// + /// # Arguments + /// * `origin` - The origin to be added. + pub fn add_author(&mut self, author: String) { + self.author = StringValue::from_string(author); + } + + /// Adds an origin to the origin struct. + /// + /// # Arguments + pub fn add_origin(&mut self, origin: String) -> Result<(), SurrealError> { + self.origin = OriginValue::from_string(origin)?; + Ok(()) + } + + /// Converts an origin to a string. + /// + /// # Returns + /// The origin as a string. + pub fn to_string(&self) -> String { + if self.author.value.is_none() && self.origin == OriginValue::None(StringValue::fresh()) { + return String::from("") + } + format!("{}=>{}", self.author.to_string(), self.origin.to_string()) + } + + /// Creates a new origin from a string. + /// + /// # Arguments + /// * `origin` - The origin as a string. + /// + /// # Returns + /// A new origin. 
+ pub fn from_string(origin: String) -> Result { + if origin == "".to_string() { + return Ok(Origin::fresh()); + } + let mut split = origin.split("=>"); + let author = split.next().unwrap().to_string(); + let origin = split.next().unwrap().to_string(); + Ok(Origin { + origin: OriginValue::from_string(origin)?, + author: StringValue::from_string(author), + }) + } + +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_fresh() { + let origin = Origin::fresh(); + assert_eq!(origin, Origin { + origin: OriginValue::fresh(), + author: StringValue::fresh(), + }); + } + + #[test] + fn test_to_string() { + let origin = Origin { + origin: OriginValue::from_string("local".to_string()).unwrap(), + author: StringValue::from_string("author".to_string()), + }; + assert_eq!(origin.to_string(), "author=>local".to_string()); + + let origin = Origin::fresh(); + assert_eq!(origin.to_string(), "".to_string()); + } + + #[test] + fn test_from_string() { + let origin = Origin::from_string("author=>local".to_string()).unwrap(); + assert_eq!(origin, Origin { + origin: OriginValue::from_string("local".to_string()).unwrap(), + author: StringValue::from_string("author".to_string()), + }); + + let origin = Origin::from_string("=>local".to_string()).unwrap(); + + assert_eq!(None, origin.author.value); + assert_eq!("local".to_string(), origin.origin.to_string()); + } + +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/output.rs b/modules/c-wrapper/build-context/core/src/storage/header/output.rs new file mode 100644 index 0000000..19cf560 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/output.rs @@ -0,0 +1,172 @@ +//! Defines the struct housing data around the outputs of the model. +use super::normalisers::wrapper::NormaliserType; +use crate::{ + safe_eject_option, + errors::error::{ + SurrealError, + SurrealErrorStatus + } +}; + + +/// Houses data around the outputs of the model. 
+/// +/// # Fields +/// * `name` - The name of the output. +/// * `normaliser` - The normaliser to be applied to the output if there is one. +#[derive(Debug, PartialEq)] +pub struct Output { + pub name: Option, + pub normaliser: Option, +} + +impl Output { + + /// Creates a new instance of the Output struct with no normaliser or name. + /// + /// # Returns + /// A new instance of the Output struct with no normaliser or name. + pub fn fresh() -> Self { + Output { + name: None, + normaliser: None, + } + } + + /// Creates a new instance of the Output struct without a normaliser. + /// + /// # Arguments + /// * `name` - The name of the output. + pub fn new(name: String) -> Self { + Output { + name: Some(name), + normaliser: None, + } + } + + /// Adds a normaliser to the output. + /// + /// # Arguments + /// * `normaliser` - The normaliser to be applied to the output. + pub fn add_normaliser(&mut self, normaliser: NormaliserType) { + self.normaliser = Some(normaliser); + } + + /// Converts the output struct to a string. + /// + /// # Returns + /// * `String` - The output struct as a string. + pub fn to_string(&self) -> String { + + if &self.name == &None && &self.normaliser == &None { + return "".to_string(); + } + + let name = match &self.name { + Some(name) => name.clone(), + None => "none".to_string(), + }; + let mut buffer = vec![ + name.clone(), + ]; + match &self.normaliser { + Some(normaliser) => buffer.push(normaliser.to_string()), + None => buffer.push("none".to_string()), + } + buffer.join("=>") + } + + /// Converts a string to an instance of the Output struct. + /// + /// # Arguments + /// * `data` - The string to be converted into an instance of the Output struct. + /// + /// # Returns + /// * `Output` - The string as an instance of the Output struct. 
+ pub fn from_string(data: String) -> Result { + if data.contains("=>") == false { + return Ok(Output::fresh()) + } + let mut buffer = data.split("=>"); + + let name = safe_eject_option!(buffer.next()); + let name = match name { + "none" => None, + _ => Some(name.to_string()), + }; + + let normaliser = safe_eject_option!(buffer.next()); + let normaliser = match normaliser { + "none" => None, + _ => Some(NormaliserType::from_string(data).unwrap().0), + }; + return Ok(Output { + name, + normaliser + }) + } +} + + +#[cfg(test)] +pub mod tests { + + use super::*; + + #[test] + fn test_output_to_string() { + + // with no normaliser + let mut output = Output::new("test".to_string()); + assert_eq!(output.to_string(), "test=>none"); + + let normaliser_data = "a=>linear_scaling(0.0,1.0)".to_string(); + let normaliser = NormaliserType::from_string(normaliser_data).unwrap(); + + output.add_normaliser(normaliser.0); + assert_eq!(output.to_string(), "test=>linear_scaling(0,1)"); + } + + #[test] + fn test_from_string() { + let data = "test=>linear_scaling(0,1)".to_string(); + let output = Output::from_string(data).unwrap(); + + assert_eq!(output.name.unwrap(), "test"); + assert_eq!(output.normaliser.is_some(), true); + assert_eq!(output.normaliser.unwrap().to_string(), "linear_scaling(0,1)"); + } + + #[test] + fn test_from_string_with_no_normaliser() { + let data = "test=>none".to_string(); + let output = Output::from_string(data).unwrap(); + + assert_eq!(output.name.unwrap(), "test"); + assert_eq!(output.normaliser.is_none(), true); + } + + #[test] + fn test_from_string_with_no_name() { + let data = "none=>none".to_string(); + let output = Output::from_string(data).unwrap(); + + assert_eq!(output.name.is_none(), true); + assert_eq!(output.normaliser.is_none(), true); + } + + #[test] + fn test_from_string_with_empty_string() { + let data = "".to_string(); + let output = Output::from_string(data).unwrap(); + + assert_eq!(output.name.is_none(), true); + 
assert_eq!(output.normaliser.is_none(), true); + } + + #[test] + fn test_to_string_with_no_data() { + let output = Output::fresh(); + assert_eq!(output.to_string(), ""); + } +} diff --git a/modules/c-wrapper/build-context/core/src/storage/header/string_value.rs b/modules/c-wrapper/build-context/core/src/storage/header/string_value.rs new file mode 100644 index 0000000..e967624 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/string_value.rs @@ -0,0 +1,99 @@ +//! Defines a generic string value for the header. + + +/// Defines a generic string value for the header. +/// +/// # Fields +/// * `value` - The value of the string. +#[derive(Debug, PartialEq)] +pub struct StringValue { + pub value: Option, +} + + +impl StringValue { + + /// Creates a new string value with no value. + /// + /// # Returns + /// A new string value with no value. + pub fn fresh() -> Self { + StringValue { + value: None, + } + } + + /// Creates a new string value from a string. + /// + /// # Arguments + /// * `value` - The value of the string. + /// + /// # Returns + /// A new string value. + pub fn from_string(value: String) -> Self { + match value.as_str() { + "" => StringValue { + value: None, + }, + _ => StringValue { + value: Some(value), + }, + } + } + + /// Converts the string value to a string. + /// + /// # Returns + /// The string value as a string. 
+ pub fn to_string(&self) -> String { + match &self.value { + Some(value) => value.to_string(), + None => String::from(""), + } + } + +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fresh() { + let string_value = StringValue::fresh(); + assert_eq!(string_value, StringValue { + value: None, + }); + } + + #[test] + fn test_from_string() { + let string_value = StringValue::from_string(String::from("test")); + assert_eq!(string_value, StringValue { + value: Some(String::from("test")), + }); + } + + #[test] + fn test_from_string_none() { + let string_value = StringValue::from_string(String::from("")); + assert_eq!(string_value, StringValue { + value: None, + }); + } + + #[test] + fn test_to_string() { + let string_value = StringValue::from_string(String::from("test")); + assert_eq!(string_value.to_string(), String::from("test")); + } + + #[test] + fn test_to_string_none() { + let string_value = StringValue { + value: None, + }; + assert_eq!(string_value.to_string(), String::from("")); + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/header/version.rs b/modules/c-wrapper/build-context/core/src/storage/header/version.rs new file mode 100644 index 0000000..cd963f4 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/header/version.rs @@ -0,0 +1,159 @@ +//! Defines the process of managing the version of the `surml` file in the file. +use crate::{ + safe_eject_option, + safe_eject, + errors::error::{ + SurrealError, + SurrealErrorStatus + } +}; + + +/// The `Version` struct represents the version of the `surml` file. +/// +/// # Fields +/// * `one` - The first number in the version. +/// * `two` - The second number in the version. +/// * `three` - The third number in the version. +#[derive(Debug, PartialEq)] +pub struct Version { + pub one: u8, + pub two: u8, + pub three: u8, +} + + +impl Version { + + /// Creates a new `Version` struct with all zeros. 
+ /// + /// # Returns + /// A new `Version` struct with all zeros. + pub fn fresh() -> Self { + Version { + one: 0, + two: 0, + three: 0, + } + } + + /// Translates the struct to a string. + /// + /// # Returns + /// * `String` - The struct as a string. + pub fn to_string(&self) -> String { + if self.one == 0 && self.two == 0 && self.three == 0 { + return "".to_string(); + } + format!("{}.{}.{}", self.one, self.two, self.three) + } + + /// Creates a new `Version` struct from a string. + /// + /// # Arguments + /// * `version` - The version as a string. + /// + /// # Returns + /// A new `Version` struct. + pub fn from_string(version: String) -> Result { + if version == "".to_string() { + return Ok(Version::fresh()) + } + let mut split = version.split("."); + let one_str = safe_eject_option!(split.next()); + let two_str = safe_eject_option!(split.next()); + let three_str = safe_eject_option!(split.next()); + + Ok(Version { + one: safe_eject!(one_str.parse::(), SurrealErrorStatus::BadRequest), + two: safe_eject!(two_str.parse::(), SurrealErrorStatus::BadRequest), + three: safe_eject!(three_str.parse::(), SurrealErrorStatus::BadRequest), + }) + } + + /// Increments the version by one. 
+ pub fn increment(&mut self) { + self.three += 1; + if self.three == 10 { + self.three = 0; + self.two += 1; + if self.two == 10 { + self.two = 0; + self.one += 1; + } + } + } +} + + +#[cfg(test)] +pub mod tests { + + use super::*; + + #[test] + fn test_from_string() { + let version = Version::from_string("0.0.0".to_string()).unwrap(); + assert_eq!(version.one, 0); + assert_eq!(version.two, 0); + assert_eq!(version.three, 0); + + let version = Version::from_string("1.2.3".to_string()).unwrap(); + assert_eq!(version.one, 1); + assert_eq!(version.two, 2); + assert_eq!(version.three, 3); + } + + #[test] + fn test_to_string() { + let version = Version{ + one: 0, + two: 0, + three: 0, + }; + assert_eq!(version.to_string(), ""); + + let version = Version{ + one: 1, + two: 2, + three: 3, + }; + assert_eq!(version.to_string(), "1.2.3"); + } + + #[test] + fn test_increment() { + let mut version = Version{ + one: 0, + two: 0, + three: 0, + }; + version.increment(); + assert_eq!(version.to_string(), "0.0.1"); + + let mut version = Version{ + one: 0, + two: 0, + three: 9, + }; + version.increment(); + assert_eq!(version.to_string(), "0.1.0"); + + let mut version = Version{ + one: 0, + two: 9, + three: 9, + }; + version.increment(); + assert_eq!(version.to_string(), "1.0.0"); + + let mut version = Version{ + one: 9, + two: 9, + three: 9, + }; + version.increment(); + assert_eq!(version.to_string(), "10.0.0"); + } + +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/src/storage/mod.rs b/modules/c-wrapper/build-context/core/src/storage/mod.rs new file mode 100644 index 0000000..0b70a87 --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/mod.rs @@ -0,0 +1,4 @@ +//! Responsible for the saving and loading of the model including meta data around the model. 
+pub mod header; +pub mod surml_file; +pub mod stream_adapter; diff --git a/modules/c-wrapper/build-context/core/src/storage/stream_adapter.rs b/modules/c-wrapper/build-context/core/src/storage/stream_adapter.rs new file mode 100644 index 0000000..a5ad91a --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/stream_adapter.rs @@ -0,0 +1,71 @@ +//! Stream adapter for file system +use std::fs::File; +use std::io::Read; +use bytes::Bytes; + +use futures_core::stream::Stream; +use futures_core::task::{Context, Poll}; +use std::pin::Pin; +use std::error::Error; +use crate::{ + safe_eject, + errors::error::{ + SurrealError, + SurrealErrorStatus + } +}; + + +/// Stream adapter for file system. +/// +/// # Arguments +/// * `chunk_size` - The size of the chunks to read from the file. +/// * `file_pointer` - The pointer to the file to be streamed +pub struct StreamAdapter { + chunk_size: usize, + file_pointer: File +} + +impl StreamAdapter { + + /// Creates a new `StreamAdapter` struct. + /// + /// # Arguments + /// * `chunk_size` - The size of the chunks to read from the file. + /// * `file_path` - The path to the file to be streamed + /// + /// # Returns + /// A new `StreamAdapter` struct. + pub fn new(chunk_size: usize, file_path: String) -> Result { + let file_pointer = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound); + Ok(StreamAdapter { + chunk_size, + file_pointer + }) + } + +} + +impl Stream for StreamAdapter { + + type Item = Result>; + + /// Polls the next chunk from the file. + /// + /// # Arguments + /// * `self` - The `StreamAdapter` struct. + /// * `cx` - The context of the task to enable the task to be woken up and polled again using the waker. + /// + /// # Returns + /// A poll containing the next chunk from the file. 
+ fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let mut buffer = vec![0u8; self.chunk_size]; + let bytes_read = self.file_pointer.read(&mut buffer)?; + + buffer.truncate(bytes_read); + if buffer.is_empty() { + return Poll::Ready(None); + } + Poll::Ready(Some(Ok(buffer.into()))) + } +} diff --git a/modules/c-wrapper/build-context/core/src/storage/surml_file.rs b/modules/c-wrapper/build-context/core/src/storage/surml_file.rs new file mode 100644 index 0000000..aa056ca --- /dev/null +++ b/modules/c-wrapper/build-context/core/src/storage/surml_file.rs @@ -0,0 +1,227 @@ +//! Defines the saving and loading of the entire `surml` file. +use std::fs::File; +use std::io::{Read, Write}; + +use crate::{ + safe_eject_internal, + safe_eject, + storage::header::Header, + errors::error::{ + SurrealError, + SurrealErrorStatus + } +}; + + +/// The `SurMlFile` struct represents the entire `surml` file. +/// +/// # Fields +/// * `header` - The header of the `surml` file containing data such as key bindings for inputs and normalisers. +/// * `model` - The PyTorch model in C. +pub struct SurMlFile { + pub header: Header, + pub model: Vec, +} + + +impl SurMlFile { + + /// Creates a new `SurMlFile` struct with an empty header. + /// + /// # Arguments + /// * `model` - The PyTorch model in C. + /// + /// # Returns + /// A new `SurMlFile` struct with no columns or normalisers. + pub fn fresh(model: Vec) -> Self { + Self { + header: Header::fresh(), + model + } + } + + /// Creates a new `SurMlFile` struct. + /// + /// # Arguments + /// * `header` - The header of the `surml` file containing data such as key bindings for inputs and normalisers. + /// * `model` - The PyTorch model in C. + /// + /// # Returns + /// A new `SurMlFile` struct. + pub fn new(header: Header, model: Vec) -> Self { + Self { + header, + model, + } + } + + /// Creates a new `SurMlFile` struct from a vector of bytes. 
+ /// + /// # Arguments + /// * `bytes` - A vector of bytes representing the header and the model. + /// + /// # Returns + /// A new `SurMlFile` struct. + pub fn from_bytes(bytes: Vec) -> Result { + // check to see if there is enough bytes to read + if bytes.len() < 4 { + return Err( + SurrealError::new( + "Not enough bytes to read".to_string(), + SurrealErrorStatus::BadRequest + ) + ); + } + let mut header_bytes = Vec::new(); + let mut model_bytes = Vec::new(); + + // extract the first 4 bytes as an integer to get the length of the header + let mut buffer = [0u8; 4]; + buffer.copy_from_slice(&bytes[0..4]); + let integer_value = u32::from_be_bytes(buffer); + + // check to see if there is enough bytes to read + if bytes.len() < (4 + integer_value as usize) { + return Err( + SurrealError::new( + "Not enough bytes to read for header, maybe the file format is not correct".to_string(), + SurrealErrorStatus::BadRequest + ) + ); + } + + // Read the next integer_value bytes for the header + header_bytes.extend_from_slice(&bytes[4..(4 + integer_value as usize)]); + + // Read the remaining bytes for the model + model_bytes.extend_from_slice(&bytes[(4 + integer_value as usize)..]); + + // construct the header and C model from the bytes + let header = Header::from_bytes(header_bytes)?; + let model = model_bytes; + Ok(Self { + header, + model, + }) + } + + /// Creates a new `SurMlFile` struct from a file. + /// + /// # Arguments + /// * `file_path` - The path to the `surml` file. + /// + /// # Returns + /// A new `SurMlFile` struct. 
+ pub fn from_file(file_path: &str) -> Result { + let mut file = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound); + + // extract the first 4 bytes as an integer to get the length of the header + let mut buffer = [0u8; 4]; + safe_eject!(file.read_exact(&mut buffer), SurrealErrorStatus::BadRequest); + let integer_value = u32::from_be_bytes(buffer); + + // Read the next integer_value bytes for the header + let mut header_buffer = vec![0u8; integer_value as usize]; + safe_eject!(file.read_exact(&mut header_buffer), SurrealErrorStatus::BadRequest); + + // Create a Vec to store the data + let mut model_buffer = Vec::new(); + + // Read the rest of the file into the buffer + safe_eject!(file.take(usize::MAX as u64).read_to_end(&mut model_buffer), SurrealErrorStatus::BadRequest); + + // construct the header and C model from the bytes + let header = Header::from_bytes(header_buffer)?; + Ok(Self { + header, + model: model_buffer, + }) + } + + /// Converts the header and the model to a vector of bytes. + /// + /// # Returns + /// A vector of bytes representing the header and the model. + pub fn to_bytes(&self) -> Vec { + // compile the header into bytes. + let (num, header_bytes) = self.header.to_bytes(); + let num_bytes = i32::to_be_bytes(num).to_vec(); + + // combine the bytes into a single vector + let mut combined_vec: Vec = Vec::new(); + combined_vec.extend(num_bytes); + combined_vec.extend(header_bytes); + combined_vec.extend(self.model.clone()); + return combined_vec + } + + /// Writes the header and the model to a `surml` file. + /// + /// # Arguments + /// * `file_path` - The path to the `surml` file. + /// + /// # Returns + /// An `io::Result` indicating whether the write was successful. 
+ pub fn write(&self, file_path: &str) -> Result<(), SurrealError> { + let combined_vec = self.to_bytes(); + + // write the bytes to a file + let mut file = safe_eject_internal!(File::create(file_path)); + safe_eject_internal!(file.write(&combined_vec)); + Ok(()) + } +} + + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_write() { + let mut header = Header::fresh(); + header.add_column(String::from("squarefoot")); + header.add_column(String::from("num_floors")); + header.add_output(String::from("house_price"), None); + + let mut file = File::open("./stash/linear_test.onnx").unwrap(); + + let mut model_bytes = Vec::new(); + file.read_to_end(&mut model_bytes).unwrap(); + + let surml_file = SurMlFile::new(header, model_bytes); + surml_file.write("./stash/test.surml").unwrap(); + + let _ = SurMlFile::from_file("./stash/test.surml").unwrap(); + } + + #[test] + fn test_write_forrest() { + + let header = Header::fresh(); + + let mut file = File::open("./stash/forrest_test.onnx").unwrap(); + + let mut model_bytes = Vec::new(); + file.read_to_end(&mut model_bytes).unwrap(); + + let surml_file = SurMlFile::new(header, model_bytes); + surml_file.write("./stash/forrest.surml").unwrap(); + + let _ = SurMlFile::from_file("./stash/forrest.surml").unwrap(); + + } + + #[test] + fn test_empty_buffer() { + let bytes = vec![0u8; 0]; + match SurMlFile::from_bytes(bytes) { + Ok(_) => assert!(false), + Err(error) => { + assert_eq!(error.status, SurrealErrorStatus::BadRequest); + assert_eq!(error.to_string(), "Not enough bytes to read"); + } + } + } +} \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/stash/forrest.surml b/modules/c-wrapper/build-context/core/stash/forrest.surml new file mode 100644 index 0000000..41ee70c Binary files /dev/null and b/modules/c-wrapper/build-context/core/stash/forrest.surml differ diff --git a/modules/c-wrapper/build-context/core/stash/forrest_test.onnx 
b/modules/c-wrapper/build-context/core/stash/forrest_test.onnx new file mode 100644 index 0000000..4fa390d Binary files /dev/null and b/modules/c-wrapper/build-context/core/stash/forrest_test.onnx differ diff --git a/modules/c-wrapper/build-context/core/stash/linear_test.onnx b/modules/c-wrapper/build-context/core/stash/linear_test.onnx new file mode 100644 index 0000000..f7b070b --- /dev/null +++ b/modules/c-wrapper/build-context/core/stash/linear_test.onnx @@ -0,0 +1,15 @@ +pytorch2.0.1:… +Q +onnx::MatMul_0 +onnx::MatMul_6/linear/MatMul_output_0/linear/MatMul"MatMul +; + linear.bias +/linear/MatMul_output_05 /linear/Add"Add torch_jit*B linear.biasJdÕ ²* Bonnx::MatMul_6J ‘9?[ÄŒ>Z +onnx::MatMul_0 + + +b +5 + + +B \ No newline at end of file diff --git a/modules/c-wrapper/build-context/core/stash/test.surml b/modules/c-wrapper/build-context/core/stash/test.surml new file mode 100644 index 0000000..61da29a Binary files /dev/null and b/modules/c-wrapper/build-context/core/stash/test.surml differ diff --git a/modules/c-wrapper/build.rs b/modules/c-wrapper/build.rs new file mode 100644 index 0000000..bc0e78e --- /dev/null +++ b/modules/c-wrapper/build.rs @@ -0,0 +1,121 @@ +use std::env; +use std::fs; +use std::io::prelude::*; +use std::fs::File; +use std::io::Cursor; +use std::path::Path; +use reqwest::blocking::get; +use flate2::read::GzDecoder; +use tar::Archive; + +fn main() { + let version = "1.20.1"; + let root_dir_str = env::var("OUT_DIR").unwrap(); + let root_dir = Path::new(&root_dir_str); + let current_working_dir = std::env::current_dir().unwrap(); + let out_dir = std::env::current_dir().unwrap().join("onnx_lib"); + + // Create output directory + fs::create_dir_all(&out_dir).expect("Failed to create output directory"); + + // Detect OS and architecture + let target_os = env::var("CARGO_CFG_TARGET_OS").expect("Failed to get target OS"); + let target_arch = env::var("CARGO_CFG_TARGET_ARCH").expect("Failed to get target architecture"); + let target_env = 
env::var("CARGO_CFG_TARGET_ENV").unwrap_or_default(); // Optional: For specific environments like MSVC + + // Map to appropriate URL + let file_extension = match target_os.as_str() { + "windows" => "zip", + _ => "tgz", + }; + + // Construct the directory name + // linux aarch64 + let directory_name = match (target_os.as_str(), target_arch.as_str()) { + ("linux", "aarch64") => format!("onnxruntime-linux-aarch64-{version}"), + ("linux", "x86_64") => format!("onnxruntime-linux-x64-{version}"), + ("macos", "aarch64") => format!("onnxruntime-osx-arm64-{version}"), + ("macos", "x86_64") => format!("onnxruntime-osx-x86_64-{version}"), + ("windows", "aarch64") => format!("onnxruntime-win-arm64-{version}"), + ("windows", "x86_64") => format!("onnxruntime-win-x64-{version}"), + ("windows", "x86") => format!("onnxruntime-win-x86-{version}"), + _ => panic!("Unsupported OS/architecture combination"), + }; + println!("build directory defined: {}", directory_name); + + let filename = match (target_os.as_str(), target_arch.as_str(), target_env.as_str()) { + ("linux", "aarch64", _) => format!("onnxruntime-linux-aarch64-{version}.{file_extension}"), + ("linux", "x86_64", _) => { + if cfg!(feature = "gpu") { + format!("onnxruntime-linux-x64-gpu-{version}.{file_extension}") + } else { + format!("onnxruntime-linux-x64-{version}.{file_extension}") + } + } + ("macos", "aarch64", _) => format!("onnxruntime-osx-arm64-{version}.{file_extension}"), + ("macos", "x86_64", _) => format!("onnxruntime-osx-x86_64-{version}.{file_extension}"), + ("windows", "x86_64", _) => { + if cfg!(feature = "gpu") { + format!("onnxruntime-win-x64-gpu-{version}.{file_extension}") + } else { + format!("onnxruntime-win-x64-{version}.{file_extension}") + } + } + ("windows", "x86", _) => format!("onnxruntime-win-x86-{version}.{file_extension}"), + ("windows", "aarch64", _) => format!("onnxruntime-win-arm64-{version}.{file_extension}"), + _ => panic!("Unsupported OS/architecture combination"), + }; + println!("build 
filename defined: {}", filename);
+
+    let url = format!(
+        "https://github.com/microsoft/onnxruntime/releases/download/v{version}/{filename}"
+    );
+
+    // Download and extract
+    println!("Downloading ONNX Runtime from {}", url);
+    let response = get(&url).expect("Failed to send request");
+    if !response.status().is_success() {
+        panic!("Failed to download ONNX Runtime: HTTP {}", response.status());
+    }
+    println!("Downloaded ONNX Runtime successfully");
+
+    if file_extension == "tgz" {
+        let tar_gz = GzDecoder::new(Cursor::new(response.bytes().expect("Failed to read response bytes")));
+        let mut archive = Archive::new(tar_gz);
+        archive.unpack(&out_dir).expect("Failed to extract archive");
+    } else if file_extension == "zip" {
+        let mut archive = zip::ZipArchive::new(Cursor::new(
+            response.bytes().expect("Failed to read response bytes"),
+        ))
+        .expect("Failed to open ZIP archive");
+        archive.extract(&out_dir).expect("Failed to extract ZIP archive");
+    }
+    println!("Extracted ONNX Runtime successfully");
+
+    let lib_filename = match target_os.as_str() {
+        "windows" => "onnxruntime.dll",
+        "macos" => "libonnxruntime.dylib",
+        _ => "libonnxruntime.so",
+    };
+    println!("lib filename defined: {}", lib_filename);
+
+    let output_dir = Path::new(&out_dir);
+    let lib_path = output_dir.join(directory_name.clone()).join("lib").join(lib_filename);
+
+    // copy the library to the output directory
+    fs::copy(&lib_path, Path::new(&out_dir).join("onnxruntime")).expect("Failed to copy library");
+
+    let path_data = format!("Copied library to output directory {} -> {}", lib_path.display(), out_dir.display());
+
+    let mut file = File::create(current_working_dir.join("build_output.txt")).unwrap();
+    file.write_all(path_data.as_bytes()).unwrap();
+
+    println!("{}", path_data);
+    // remove the out_dir
+    fs::remove_dir_all(&output_dir.join(directory_name)).expect("Failed to remove output directory");
+
+    let output_lib = Path::new(&out_dir).join("onnxruntime");
+
+    // link the library
+    
println!("cargo:rustc-env=ORT_LIB_LOCATION={}", output_lib.display()); +} diff --git a/modules/c-wrapper/corss-build b/modules/c-wrapper/corss-build new file mode 100644 index 0000000..9cc8068 --- /dev/null +++ b/modules/c-wrapper/corss-build @@ -0,0 +1,45 @@ +FROM rust:1.81 + +# Install necessary tools and dependencies +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + clang \ + cmake \ + curl \ + file \ + git \ + libssl-dev \ + pkg-config \ + python3 \ + qemu-user-static \ + wget \ + xz-utils \ + zlib1g-dev \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +# Install Docker (for cross) +RUN curl -fsSL https://get.docker.com | sh + +# Install Cross +RUN cargo install cross --git https://github.com/cross-rs/cross + +# Add macOS cross-compilation toolchain using osxcross +# RUN git clone https://github.com/tpoechtrager/osxcross.git /osxcross \ +# && cd ../osxcross \ +# && wget -q https://github.com/tpoechtrager/osxcross/releases/download/v1.1/MacOSX10.11.sdk.tar.xz -O tarballs/MacOSX10.11.sdk.tar.xz \ +# && UNATTENDED=1 ./build.sh + +# ENV PATH="/osxcross/target/bin:$PATH" +# ENV CROSS_CONTAINER_IN_CONTAINER=true +# ENV MACOSX_DEPLOYMENT_TARGET=11.0 + +# # Add Windows cross-compilation toolchain +# RUN rustup target add x86_64-pc-windows-gnu aarch64-pc-windows-gnu + +# # Add Linux cross-compilation toolchains +# RUN rustup target add x86_64-unknown-linux-gnu aarch64-unknown-linux-gnu + +# Set up entrypoint for container +WORKDIR /project + +ENTRYPOINT ["/bin/bash"] \ No newline at end of file diff --git a/modules/c-wrapper/scripts/build-docker.sh b/modules/c-wrapper/scripts/build-docker.sh new file mode 100644 index 0000000..7ffdba6 --- /dev/null +++ b/modules/c-wrapper/scripts/build-docker.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. 
+ +# wipe and build the build context +BUILD_DIR="build-context" +if [ -d "$BUILD_DIR" ]; then + echo "Cleaning up existing build directory..." + rm -rf "$BUILD_DIR" +fi +mkdir "$BUILD_DIR" +mkdir "$BUILD_DIR"/c-wrapper + +# copy over the code to be built +cp -r src "$BUILD_DIR"/c-wrapper/src +cp -r tests "$BUILD_DIR"/c-wrapper/tests +cp -r scripts "$BUILD_DIR"/c-wrapper/scripts +cp Cargo.toml "$BUILD_DIR"/c-wrapper/Cargo.toml +cp build.rs "$BUILD_DIR"/c-wrapper/build.rs +cp -r ../core "$BUILD_DIR"/core +cp Dockerfile "$BUILD_DIR"/Dockerfile + +# remove unnecessary files +rm -rf "$BUILD_DIR"/core/.git +rm -rf "$BUILD_DIR"/core/target/ + +# build the docker image +cd "$BUILD_DIR" +docker build --no-cache -t c-wrapper-tests . + +docker run c-wrapper-tests +# docker run -it c-wrapper-tests /bin/bash diff --git a/modules/c-wrapper/scripts/copy_over_lib.sh b/modules/c-wrapper/scripts/copy_over_lib.sh new file mode 100644 index 0000000..135212d --- /dev/null +++ b/modules/c-wrapper/scripts/copy_over_lib.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. 
+OS=$(uname) + +# Set the library name and extension based on the OS +case "$OS" in + "Linux") + LIB_NAME="libc_wrapper.so" + ;; + "Darwin") + LIB_NAME="libc_wrapper.dylib" + ;; + "CYGWIN"*|"MINGW"*) + LIB_NAME="libc_wrapper.dll" + ;; + *) + echo "Unsupported operating system: $OS" + exit 1 + ;; +esac + +# Source directory (where Cargo outputs the compiled library) +SOURCE_DIR="target/debug" + +# Destination directory (tests directory) +DEST_DIR="tests/test_utils" + +# Destination directory (onnxruntime library) +LIB_PATH="onnx_lib/onnxruntime" + + +cp "$SOURCE_DIR/$LIB_NAME" "$DEST_DIR/" +cp "$LIB_PATH" "$DEST_DIR/" diff --git a/modules/c-wrapper/scripts/prep_tests.sh b/modules/c-wrapper/scripts/prep_tests.sh new file mode 100644 index 0000000..9bcd32a --- /dev/null +++ b/modules/c-wrapper/scripts/prep_tests.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd .. + +# download onnxruntime +# Detect operating system +OS=$(uname -s | tr '[:upper:]' '[:lower:]') + +# Detect architecture +ARCH=$(uname -m) + +# Download the correct onnxruntime +if [ "$ARCH" == "x86_64" ] && [ "$OS" == "linux" ]; then + wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz + tar -xvf onnxruntime-linux-x64-1.20.0.tgz + mv onnxruntime-linux-x64-1.20.0 tests/test_utils/onnxruntime +else + echo "Unsupported operating system and arch: $OS $ARCH" + exit 1 +fi + +export ORT_LIB_LOCATION=$(pwd)/tests/test_utils/onnxruntime/lib +export LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +cargo build + +# Get the operating system +OS=$(uname) + +# Set the library name and extension based on the OS +case "$OS" in + "Linux") + LIB_NAME="libc_wrapper.so" + ;; + "Darwin") + LIB_NAME="libc_wrapper.dylib" + ;; + "CYGWIN"*|"MINGW"*) + LIB_NAME="libc_wrapper.dll" + ;; + *) + echo "Unsupported operating system: $OS" + exit 1 + ;; +esac + +# Source directory (where Cargo 
outputs the compiled library)
+SOURCE_DIR="target/debug"
+
+# Destination directory (tests directory)
+DEST_DIR="tests/test_utils"
+
+
+# Copy the library to the tests directory
+if [ -f "$SOURCE_DIR/$LIB_NAME" ]; then
+    cp "$SOURCE_DIR/$LIB_NAME" "$DEST_DIR/"
+    echo "Copied $LIB_NAME to $DEST_DIR"
+else
+    echo "Library not found: $SOURCE_DIR/$LIB_NAME"
+    exit 1
+fi
diff --git a/modules/c-wrapper/scripts/run_cross.sh b/modules/c-wrapper/scripts/run_cross.sh
new file mode 100644
index 0000000..c9ef568
--- /dev/null
+++ b/modules/c-wrapper/scripts/run_cross.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+# navigate to directory
+SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
+cd $SCRIPTPATH
+
+cd ..
+
+dockerd
+sudo systemctl start docker
+cross build --target aarch64-unknown-linux-gnu
+
+
+docker run --rm -it \
+    -v "$(pwd):/project" \
+    -v /var/run/docker.sock:/var/run/docker.sock \
+    -w /project \
+    rust-cross-compiler
diff --git a/modules/c-wrapper/scripts/run_tests.sh b/modules/c-wrapper/scripts/run_tests.sh
new file mode 100644
index 0000000..cfe54f0
--- /dev/null
+++ b/modules/c-wrapper/scripts/run_tests.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+# navigate to directory
+SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )"
+cd $SCRIPTPATH
+
+cd ..
+
+cd tests
+
+python3 -m unittest discover .
diff --git a/modules/c-wrapper/src/api/execution/buffered_compute.rs b/modules/c-wrapper/src/api/execution/buffered_compute.rs
new file mode 100644
index 0000000..4ac7e5e
--- /dev/null
+++ b/modules/c-wrapper/src/api/execution/buffered_compute.rs
@@ -0,0 +1,164 @@
+//! This module contains the buffered_compute function that is called from the C API to compute the model.
+use crate::state::STATE; +use std::ffi::{c_float, CStr, CString, c_int, c_char}; +use surrealml_core::execution::compute::ModelComputation; +use crate::utils::Vecf32Return; +use std::collections::HashMap; + + +/// Computes the model with the given data. +/// +/// # Arguments +/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model. +/// * `data_ptr` - A pointer to the data to compute. +/// * `length` - The length of the data. +/// * `strings` - A pointer to an array of strings to use as keys for the data. +/// * `string_count` - The number of strings in the array. +/// +/// # Returns +/// A Vecf32Return object containing the outcome of the computation. +#[no_mangle] +pub extern "C" fn buffered_compute( + file_id_ptr: *const c_char, + data_ptr: *const c_float, + data_length: usize, + strings: *const *const c_char, + string_count: c_int +) -> Vecf32Return { + if file_id_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("File id is null").unwrap().into_raw() + } + } + if data_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("Data is null").unwrap().into_raw() + } + } + + let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { + Ok(file_id) => file_id.to_owned(), + Err(error) => return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() + } + }; + + if strings.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("string pointer is null").unwrap().into_raw() + } + } + + // extract the list of strings from the C array + let string_count = string_count as usize; + let c_strings = unsafe { std::slice::from_raw_parts(strings, string_count) 
}; + let rust_strings: Vec = c_strings + .iter() + .map(|&s| { + if s.is_null() { + String::new() + } else { + unsafe { CStr::from_ptr(s).to_string_lossy().into_owned() } + } + }) + .collect(); + for i in rust_strings.iter() { + if i.is_empty() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("null string passed in as key").unwrap().into_raw() + } + } + } + + let data_slice = unsafe { std::slice::from_raw_parts(data_ptr, data_length) }; + + if rust_strings.len() != data_slice.len() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("String count does not match data length").unwrap().into_raw() + } + } + + // stitch the strings and data together + let mut input_map = HashMap::new(); + for (i, key) in rust_strings.iter().enumerate() { + input_map.insert(key.clone(), data_slice[i]); + } + + let mut state = match STATE.lock() { + Ok(state) => state, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() + } + } + }; + let mut file = match state.get_mut(&file_id) { + Some(file) => file, + None => { + { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() + } + } + } + }; + let compute_unit = ModelComputation { + surml_file: &mut file + }; + match compute_unit.buffered_compute(&mut input_map) { + Ok(mut output) => { + let output_len = output.len(); + let output_capacity = output.capacity(); + let output_ptr = output.as_mut_ptr(); + std::mem::forget(output); + Vecf32Return { + data: output_ptr, + length: output_len, + capacity: output_capacity, + is_error: 0, + error_message: 
std::ptr::null_mut() + } + }, + Err(error) => { + Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error computing model: {}", error)).unwrap().into_raw() + } + } + } +} \ No newline at end of file diff --git a/modules/c-wrapper/src/api/execution/mod.rs b/modules/c-wrapper/src/api/execution/mod.rs new file mode 100644 index 0000000..590975c --- /dev/null +++ b/modules/c-wrapper/src/api/execution/mod.rs @@ -0,0 +1,3 @@ +//! The C API for executing ML models. +pub mod raw_compute; +pub mod buffered_compute; diff --git a/modules/c-wrapper/src/api/execution/raw_compute.rs b/modules/c-wrapper/src/api/execution/raw_compute.rs new file mode 100644 index 0000000..9f45038 --- /dev/null +++ b/modules/c-wrapper/src/api/execution/raw_compute.rs @@ -0,0 +1,108 @@ +//! This module contains the raw_compute function that is called from the C API to compute the model. +use crate::state::STATE; +use std::ffi::{c_float, CStr, CString, c_char}; +use surrealml_core::execution::compute::ModelComputation; +use crate::utils::Vecf32Return; + + +/// Computes the model with the given data. +/// +/// # Arguments +/// * `file_id_ptr` - A pointer to the unique identifier for the loaded model. +/// * `data_ptr` - A pointer to the data to compute. +/// * `length` - The length of the data. +/// +/// # Returns +/// A Vecf32Return object containing the outcome of the computation. 
+#[no_mangle] +pub extern "C" fn raw_compute(file_id_ptr: *const c_char, data_ptr: *const c_float, length: usize) -> Vecf32Return { + + if file_id_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("File id is null").unwrap().into_raw() + } + } + if data_ptr.is_null() { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new("Data is null").unwrap().into_raw() + } + } + + let file_id = match unsafe { CStr::from_ptr(file_id_ptr) }.to_str() { + Ok(file_id) => file_id.to_owned(), + Err(error) => return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting file id: {}", error)).unwrap().into_raw() + } + }; + + let mut state = match STATE.lock() { + Ok(state) => state, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error getting state: {}", error)).unwrap().into_raw() + } + } + }; + + let mut file = match state.get_mut(&file_id) { + Some(file) => file, + None => { + { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("File not found for id: {}, here is the state: {:?}", file_id, state.keys())).unwrap().into_raw() + } + } + } + }; + + let slice = unsafe { std::slice::from_raw_parts(data_ptr, length) }; + let tensor = ndarray::arr1(slice).into_dyn(); + let compute_unit = ModelComputation { + surml_file: &mut file + }; + + // perform the computation + let mut outcome = match compute_unit.raw_compute(tensor, None) { + Ok(outcome) => outcome, + Err(error) => { + return Vecf32Return { + data: std::ptr::null_mut(), + length: 0, + capacity: 0, + is_error: 1, + error_message: CString::new(format!("Error computing model: {}", 
error.message)).unwrap().into_raw() + } + } + }; + let outcome_ptr = outcome.as_mut_ptr(); + let outcome_len = outcome.len(); + let outcome_capacity = outcome.capacity(); + std::mem::forget(outcome); + Vecf32Return { + data: outcome_ptr, + length: outcome_len, + capacity: outcome_capacity, + is_error: 0, + error_message: std::ptr::null_mut() + } +} diff --git a/modules/c-wrapper/src/api/ml_sys/link_onnx.rs b/modules/c-wrapper/src/api/ml_sys/link_onnx.rs new file mode 100644 index 0000000..0760a12 --- /dev/null +++ b/modules/c-wrapper/src/api/ml_sys/link_onnx.rs @@ -0,0 +1,43 @@ +use surrealml_core::execution::session::set_environment; +use std::ffi::{c_float, CStr, CString, c_int, c_char}; +use crate::utils::EmptyReturn; + + +/// Links the onnx file to the environment. +/// +/// # Arguments +/// * `onnx_path` - The path to the onnx file. +/// +/// # Returns +/// An EmptyReturn object containing the outcome of the operation. +#[no_mangle] +pub extern "C" fn link_onnx(onnx_path: *const c_char) -> EmptyReturn { + if onnx_path.is_null() { + return EmptyReturn { + is_error: 1, + error_message: CString::new("Onnx path is null").unwrap().into_raw() + } + } + let onnx_path = match unsafe { CStr::from_ptr(onnx_path) }.to_str() { + Ok(onnx_path) => onnx_path.to_owned(), + Err(error) => return EmptyReturn { + is_error: 1, + error_message: CString::new(format!("Error getting onnx path: {}", error)).unwrap().into_raw() + } + }; + match set_environment(onnx_path) { + Ok(_) => { + EmptyReturn { + is_error: 0, + error_message: std::ptr::null_mut() + } + }, + Err(e) => { + println!("Error linking onnx file to environment: {}", e); + EmptyReturn { + is_error: 1, + error_message: CString::new(e.to_string()).unwrap().into_raw() + } + } + } +} diff --git a/modules/c-wrapper/src/api/ml_sys/mod.rs b/modules/c-wrapper/src/api/ml_sys/mod.rs new file mode 100644 index 0000000..1d6da62 --- /dev/null +++ b/modules/c-wrapper/src/api/ml_sys/mod.rs @@ -0,0 +1 @@ +pub mod link_onnx; diff --git 
a/modules/c-wrapper/src/api/mod.rs b/modules/c-wrapper/src/api/mod.rs new file mode 100644 index 0000000..1f0b9fa --- /dev/null +++ b/modules/c-wrapper/src/api/mod.rs @@ -0,0 +1,4 @@ +//! C API for interacting with the SurML file storage and executing models. +pub mod execution; +pub mod storage; +pub mod ml_sys; diff --git a/modules/c-wrapper/src/api/storage/load_cached_raw_model.rs b/modules/c-wrapper/src/api/storage/load_cached_raw_model.rs new file mode 100644 index 0000000..5f4af21 --- /dev/null +++ b/modules/c-wrapper/src/api/storage/load_cached_raw_model.rs @@ -0,0 +1,37 @@ +//! Defines the C interface for loading an ONNX model from a file and storing it in memory. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::fs::File; +use std::io::Read; +use std::os::raw::c_char; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// Local module imports +use crate::state::{generate_unique_id, STATE}; +use crate::utils::StringReturn; +use crate::{process_string_for_string_return, string_return_safe_eject}; + + + +/// Loads a ONNX model from a file wrapping it in a SurMlFile struct +/// which is stored in memory and referenced by a unique ID. +/// +/// # Arguments +/// * `file_path` - The path to the file to load. +/// +/// # Returns +/// A unique identifier for the loaded model. 
+#[no_mangle] +pub extern "C" fn load_cached_raw_model(file_path_ptr: *const c_char) -> StringReturn { + let file_path_str = process_string_for_string_return!(file_path_ptr, "file path"); + let file_id = generate_unique_id(); + let mut model = string_return_safe_eject!(File::open(file_path_str)); + let mut data = vec![]; + string_return_safe_eject!(model.read_to_end(&mut data)); + let file = SurMlFile::fresh(data); + let mut python_state = STATE.lock().unwrap(); + python_state.insert(file_id.clone(), file); + StringReturn::success(file_id) +} diff --git a/modules/c-wrapper/src/api/storage/load_model.rs b/modules/c-wrapper/src/api/storage/load_model.rs new file mode 100644 index 0000000..bb33e4f --- /dev/null +++ b/modules/c-wrapper/src/api/storage/load_model.rs @@ -0,0 +1,135 @@ +//! Defines the C interface for loading a surml file and getting the meta data around the model. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::{c_char, c_int}; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// Local module imports +use crate::state::{generate_unique_id, STATE}; + + +/// Holds the data around the outcome of the load_model function. +/// +/// # Fields +/// * `file_id` - The unique identifier for the loaded model. +/// * `name` - The name of the model. +/// * `description` - The description of the model. +/// * `version` - The version of the model. +/// * `error_message` - An error message if the loading failed. +/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success). +#[repr(C)] +pub struct FileInfo { + pub file_id: *mut c_char, + pub name: *mut c_char, + pub description: *mut c_char, + pub version: *mut c_char, + pub error_message: *mut c_char, + pub is_error: c_int, +} + + +/// Frees the memory allocated for the file info. +/// +/// # Arguments +/// * `info` - The file info to free. 
+#[no_mangle] +pub extern "C" fn free_file_info(info: FileInfo) { + // Free all allocated strings if they are not null + if !info.file_id.is_null() { + unsafe { drop(CString::from_raw(info.file_id)) }; + } + if !info.name.is_null() { + unsafe { drop(CString::from_raw(info.name)) }; + } + if !info.description.is_null() { + unsafe { drop(CString::from_raw(info.description)) }; + } + if !info.version.is_null() { + unsafe { drop(CString::from_raw(info.version)) }; + } + if !info.error_message.is_null() { + unsafe { drop(CString::from_raw(info.error_message)) }; + } +} + +/// Loads a model from a file and returns a unique identifier for the loaded model. +/// +/// # Arguments +/// * `file_path_ptr` - A pointer to the file path of the model to load. +/// +/// # Returns +/// Meta data around the model and a unique identifier for the loaded model. +#[no_mangle] +pub extern "C" fn load_model(file_path_ptr: *const c_char) -> FileInfo { + + // checking that the file path pointer is not null + if file_path_ptr.is_null() { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new("Received a null pointer for file path").unwrap().into_raw(), + is_error: 1 + }; + } + + // Convert the raw C string to a Rust string + let c_str = unsafe { CStr::from_ptr(file_path_ptr) }; + + // convert the CStr into a &str + let file_path = match c_str.to_str() { + Ok(rust_str) => rust_str, + Err(_) => { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new("Invalid UTF-8 string received for file path").unwrap().into_raw(), + is_error: 1 + }; + } + }; + + let file = match SurMlFile::from_file(&file_path) { + Ok(file) => file, + Err(e) => { + return FileInfo { + file_id: std::ptr::null_mut(), + name: std::ptr::null_mut(), + description: 
std::ptr::null_mut(), + version: std::ptr::null_mut(), + error_message: CString::new(e.to_string()).unwrap().into_raw(), + is_error: 1 + }; + } + }; + + // get the meta data from the file + let name = file.header.name.to_string(); + let description = file.header.description.to_string(); + let version = file.header.version.to_string(); + + // insert the file into the state + let file_id = generate_unique_id(); + let mut state = STATE.lock().unwrap(); + state.insert(file_id.clone(), file); + + // return the meta data + let file_id = CString::new(file_id).unwrap(); + let name = CString::new(name).unwrap(); + let description = CString::new(description).unwrap(); + let version = CString::new(version).unwrap(); + + FileInfo { + file_id: file_id.into_raw(), + name: name.into_raw(), + description: description.into_raw(), + version: version.into_raw(), + error_message: std::ptr::null_mut(), + is_error: 0 + } +} \ No newline at end of file diff --git a/modules/c-wrapper/src/api/storage/meta.rs b/modules/c-wrapper/src/api/storage/meta.rs new file mode 100644 index 0000000..d4fb797 --- /dev/null +++ b/modules/c-wrapper/src/api/storage/meta.rs @@ -0,0 +1,211 @@ +//! Defines the C API interface for interacting with the meta data of a SurML file. +// Standard library imports +use std::ffi::{CStr, CString}; +use std::os::raw::c_char; + +// External crate imports +use surrealml_core::storage::header::normalisers::wrapper::NormaliserType; + +// Local module imports +use crate::state::STATE; +use crate::utils::EmptyReturn; +use crate::{empty_return_safe_eject, process_string_for_empty_return}; + + + +/// Adds a name to the SurMlFile struct. +/// +/// # Arguments +/// * `file_id` - The unique identifier for the SurMlFile struct. +/// * `model_name` - The name of the model to be added. 
+#[no_mangle]
+pub extern "C" fn add_name(file_id_ptr: *const c_char, model_name_ptr: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
+    let model_name = process_string_for_empty_return!(model_name_ptr, "model name");
+    // NOTE(review): lock().unwrap() panics if the mutex is poisoned — confirm this is acceptable across the FFI boundary
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    wrapped_file.header.add_name(model_name);
+    EmptyReturn::success()
+}
+
+
+/// Adds a description to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `description` - The description of the model to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_description(file_id_ptr: *const c_char, description_ptr: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
+    let description = process_string_for_empty_return!(description_ptr, "description");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    wrapped_file.header.add_description(description);
+    EmptyReturn::success()
+}
+
+
+/// Adds a version to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `version` - The version of the model to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_version(file_id: *const c_char, version: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id, "file id");
+    let version = process_string_for_empty_return!(version, "version");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    // previously `let _ =` silently discarded the Result; surface the error to the caller instead
+    empty_return_safe_eject!(wrapped_file.header.add_version(version));
+    EmptyReturn::success()
+}
+
+
+/// Adds a column to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `column_name` - The name of the column to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_column(file_id: *const c_char, column_name: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id, "file id");
+    let column_name = process_string_for_empty_return!(column_name, "column name");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    wrapped_file.header.add_column(column_name);
+    EmptyReturn::success()
+}
+
+
+/// Adds an author to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `author` - The author to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_author(file_id: *const c_char, author: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id, "file id");
+    let author = process_string_for_empty_return!(author, "author");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    wrapped_file.header.add_author(author);
+    EmptyReturn::success()
+}
+
+
+/// Adds an origin of where the model was trained to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `origin` - The origin to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_origin(file_id: *const c_char, origin: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id, "file id");
+    let origin = process_string_for_empty_return!(origin, "origin");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    // previously `let _ =` silently discarded the Result; surface the error to the caller instead
+    empty_return_safe_eject!(wrapped_file.header.add_origin(origin));
+    EmptyReturn::success()
+}
+
+
+/// Adds an engine to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `engine` - The engine to be added.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn add_engine(file_id: *const c_char, engine: *const c_char) -> EmptyReturn {
+    let file_id = process_string_for_empty_return!(file_id, "file id");
+    let engine = process_string_for_empty_return!(engine, "engine");
+    let mut state = STATE.lock().unwrap();
+    let wrapped_file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    wrapped_file.header.add_engine(engine);
+    EmptyReturn::success()
+}
+
+
+/// Adds an output to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `output_name` - The name of the output to be added.
+/// * `normaliser_label` (Optional) - The label of the normaliser to be applied to the output.
+/// * `one` (Optional) - The first parameter of the normaliser.
+/// * `two` (Optional) - The second parameter of the normaliser.
+#[no_mangle]
+pub extern "C" fn add_output(
+    file_id_ptr: *const c_char,
+    output_name_ptr: *const c_char,
+    normaliser_label_ptr: *const c_char,
+    one: *const c_char,
+    two: *const c_char
+) -> EmptyReturn {
+
+    let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
+    let output_name = process_string_for_empty_return!(output_name_ptr, "output name");
+
+    let normaliser_label = if normaliser_label_ptr.is_null() {
+        None
+    }
+    else {
+        Some(process_string_for_empty_return!(normaliser_label_ptr, "normaliser label"))
+    };
+
+    // parse the optional normaliser parameters; the `::<f32>` turbofish had been lost
+    // in transit (`parse::()` does not compile)
+    let one = if one.is_null() {
+        None
+    }
+    else {
+        Some(
+            empty_return_safe_eject!(process_string_for_empty_return!(one, "one").parse::<f32>())
+        )
+    };
+    let two = if two.is_null() {
+        None
+    }
+    else {
+        Some(
+            empty_return_safe_eject!(process_string_for_empty_return!(two, "two").parse::<f32>())
+        )
+    };
+
+    let mut state = STATE.lock().unwrap();
+    let file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    if let Some(normaliser_label) = normaliser_label {
+        // previously `one.unwrap()` / `two.unwrap()` panicked across the FFI boundary when a
+        // normaliser label was supplied without both parameters; return an error instead
+        let (one, two) = match (one, two) {
+            (Some(one), Some(two)) => (one, two),
+            _ => {
+                return EmptyReturn {
+                    is_error: 1,
+                    error_message: CString::new("both 'one' and 'two' must be provided when a normaliser label is given").unwrap().into_raw()
+                };
+            }
+        };
+        let normaliser = NormaliserType::new(normaliser_label, one, two);
+        file.header.add_output(output_name, Some(normaliser));
+    }
+    else {
+        file.header.add_output(output_name, None);
+    }
+    EmptyReturn::success()
+}
+
+
+/// Adds a normaliser to the SurMlFile struct.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+/// * `column_name` - The name of the column to which the normaliser will be applied.
+/// * `normaliser_label` - The label of the normaliser to be applied to the column.
+/// * `one` - The first parameter of the normaliser.
+/// * `two` - The second parameter of the normaliser.
+#[no_mangle]
+pub extern "C" fn add_normaliser(
+    file_id_ptr: *const c_char,
+    column_name_ptr: *const c_char,
+    normaliser_label_ptr: *const c_char,
+    one: f32,
+    two: f32
+) -> EmptyReturn {
+
+    let file_id = process_string_for_empty_return!(file_id_ptr, "file id");
+    let column_name = process_string_for_empty_return!(column_name_ptr, "column name");
+    let normaliser_label = process_string_for_empty_return!(normaliser_label_ptr, "normaliser label");
+
+    let normaliser = NormaliserType::new(normaliser_label, one, two);
+    let mut state = STATE.lock().unwrap();
+    let file = empty_return_safe_eject!(state.get_mut(&file_id), "Model not found", Option);
+    // previously `let _ =` silently discarded the Result; surface the error to the caller instead
+    empty_return_safe_eject!(file.header.normalisers.add_normaliser(normaliser, column_name, &file.header.keys));
+    EmptyReturn::success()
+}
\ No newline at end of file
diff --git a/modules/c-wrapper/src/api/storage/mod.rs b/modules/c-wrapper/src/api/storage/mod.rs
new file mode 100644
index 0000000..f049a57
--- /dev/null
+++ b/modules/c-wrapper/src/api/storage/mod.rs
@@ -0,0 +1,7 @@
+//! C Storage API
+pub mod load_model;
+pub mod save_model;
+pub mod load_cached_raw_model;
+pub mod to_bytes;
+pub mod meta;
+pub mod upload_model;
diff --git a/modules/c-wrapper/src/api/storage/save_model.rs b/modules/c-wrapper/src/api/storage/save_model.rs
new file mode 100644
index 0000000..7839999
--- /dev/null
+++ b/modules/c-wrapper/src/api/storage/save_model.rs
@@ -0,0 +1,32 @@
+//! Save a model to a file, deleting the file from the `STATE` in the process.
+// Standard library imports
+use std::ffi::{CStr, CString};
+use std::os::raw::c_char;
+
+// External crate imports
+use surrealml_core::storage::surml_file::SurMlFile;
+
+// Local module imports
+use crate::state::STATE;
+use crate::utils::EmptyReturn;
+use crate::{empty_return_safe_eject, process_string_for_empty_return};
+
+
+/// Saves a model to a file, deleting the file from the `STATE` in the process.
+///
+/// # Arguments
+/// * `file_path` - The path to the file to save to.
+/// * `file_id` - The unique identifier for the loaded model.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle]
+pub extern "C" fn save_model(file_path_ptr: *const c_char, file_id_ptr: *const c_char) -> EmptyReturn {
+    let file_path_str = process_string_for_empty_return!(file_path_ptr, "file path");
+    let file_id_str = process_string_for_empty_return!(file_id_ptr, "file id");
+    let mut state = STATE.lock().unwrap();
+    let file: &mut SurMlFile = empty_return_safe_eject!(state.get_mut(&file_id_str), "Model not found", Option);
+    empty_return_safe_eject!(file.write(&file_path_str));
+    // the model is evicted from the global state once it has been written to disk,
+    // so the file id is no longer valid after a successful save
+    state.remove(&file_id_str);
+    EmptyReturn::success()
+}
diff --git a/modules/c-wrapper/src/api/storage/to_bytes.rs b/modules/c-wrapper/src/api/storage/to_bytes.rs
new file mode 100644
index 0000000..44de0d2
--- /dev/null
+++ b/modules/c-wrapper/src/api/storage/to_bytes.rs
@@ -0,0 +1,27 @@
+//! convert the entire SurML file to bytes
+// Standard library imports
+use std::ffi::{CStr, CString};
+use std::os::raw::c_char;
+
+// Local module imports
+use crate::state::STATE;
+use crate::utils::VecU8Return;
+use crate::process_string_for_vec_u8_return;
+
+
+
+/// Converts the entire SurML file to bytes.
+///
+/// # Arguments
+/// * `file_id` - The unique identifier for the SurMlFile struct.
+///
+/// # Returns
+/// A vector of bytes representing the entire file.
+#[no_mangle]
+pub extern "C" fn to_bytes(file_id_ptr: *const c_char) -> VecU8Return {
+    let file_id = process_string_for_vec_u8_return!(file_id_ptr, "file id");
+    let mut state = STATE.lock().unwrap();
+    // previously `state.get_mut(&file_id).unwrap()` panicked across the FFI boundary for an
+    // unknown file id; report the error through the return struct instead
+    let file = match state.get_mut(&file_id) {
+        Some(file) => file,
+        None => {
+            return VecU8Return {
+                data: std::ptr::null_mut(),
+                length: 0,
+                capacity: 0,
+                is_error: 1,
+                error_message: CString::new("Model not found").unwrap().into_raw()
+            };
+        }
+    };
+    let raw_bytes = file.to_bytes();
+    VecU8Return::success(raw_bytes)
+}
diff --git a/modules/c-wrapper/src/api/storage/upload_model.rs b/modules/c-wrapper/src/api/storage/upload_model.rs
new file mode 100644
index 0000000..098aae9
--- /dev/null
+++ b/modules/c-wrapper/src/api/storage/upload_model.rs
@@ -0,0 +1,86 @@
+// Standard library imports
+use std::ffi::{CStr, CString};
+use std::os::raw::c_char;
+
+// External crate imports
+use base64::encode;
+use hyper::{
+    Body, Client, Method, Request, Uri,
+    header::{AUTHORIZATION, CONTENT_TYPE, HeaderValue},
+};
+use surrealml_core::storage::stream_adapter::StreamAdapter;
+
+// Local module imports
+use crate::utils::EmptyReturn;
+use crate::{empty_return_safe_eject, process_string_for_empty_return};
+
+
+/// Uploads a model to a remote server.
+///
+/// # Arguments
+/// * `file_path_ptr` - The path to the file to upload.
+/// * `url_ptr` - The URL to upload the file to.
+/// * `chunk_size` - The size of the chunks to upload the file in.
+/// * `ns_ptr` - The namespace to upload the file to.
+/// * `db_ptr` - The database to upload the file to.
+/// * `username_ptr` - The username to use for authentication.
+/// * `password_ptr` - The password to use for authentication.
+///
+/// # Returns
+/// An empty return object indicating success or failure.
+#[no_mangle] +pub extern "C" fn upload_model( + file_path_ptr: *const c_char, + url_ptr: *const c_char, + chunk_size: usize, + ns_ptr: *const c_char, + db_ptr: *const c_char, + username_ptr: *const c_char, + password_ptr: *const c_char +) -> EmptyReturn { + // process the inputs + let file_path = process_string_for_empty_return!(file_path_ptr, "file path"); + let url = process_string_for_empty_return!(url_ptr, "url"); + let ns = process_string_for_empty_return!(ns_ptr, "namespace"); + let db = process_string_for_empty_return!(db_ptr, "database"); + let username = match username_ptr.is_null() { + true => None, + false => Some(process_string_for_empty_return!(username_ptr, "username")) + }; + let password = match password_ptr.is_null() { + true => None, + false => Some(process_string_for_empty_return!(password_ptr, "password")) + }; + + let client = Client::new(); + + let uri: Uri = empty_return_safe_eject!(url.parse()); + let generator = empty_return_safe_eject!(StreamAdapter::new(chunk_size, file_path)); + let body = Body::wrap_stream(generator); + + let part_req = Request::builder() + .method(Method::POST) + .uri(uri) + .header(CONTENT_TYPE, "application/octet-stream") + .header("surreal-ns", empty_return_safe_eject!(HeaderValue::from_str(&ns))) + .header("surreal-db", empty_return_safe_eject!(HeaderValue::from_str(&db))); + + let req; + if username.is_none() == false && password.is_none() == false { + // unwraps are safe because we have already checked that the values are not None + let encoded_credentials = encode(format!("{}:{}", username.unwrap(), password.unwrap())); + req = empty_return_safe_eject!(part_req.header(AUTHORIZATION, format!("Basic {}", encoded_credentials)) + .body(body)); + } + else { + req = empty_return_safe_eject!(part_req.body(body)); + } + + let tokio_runtime = empty_return_safe_eject!(tokio::runtime::Builder::new_current_thread().enable_io() + .enable_time() + .build()); + tokio_runtime.block_on( async move { + let _response = 
client.request(req).await.unwrap(); + }); + EmptyReturn::success() +} \ No newline at end of file diff --git a/modules/c-wrapper/src/lib.rs b/modules/c-wrapper/src/lib.rs new file mode 100644 index 0000000..b44dfd2 --- /dev/null +++ b/modules/c-wrapper/src/lib.rs @@ -0,0 +1,4 @@ +//! C lib for interacting with the SurML file storage and executing models. +mod state; +mod api; +mod utils; diff --git a/modules/c-wrapper/src/state.rs b/modules/c-wrapper/src/state.rs new file mode 100644 index 0000000..e8b56c4 --- /dev/null +++ b/modules/c-wrapper/src/state.rs @@ -0,0 +1,30 @@ +//! Defines operations for handling memory of a python program that is accessing the rust library. +// Standard library imports +use std::collections::HashMap; +use std::sync::{Arc, LazyLock, Mutex}; + +// External crate imports +use surrealml_core::storage::surml_file::SurMlFile; + +// External library imports +use uuid::Uuid; + + +/// A hashmap of unique identifiers to loaded machine learning models. As long as the python program keeps the unique +/// identifier it can access the loaded machine learning model. It is best to keep as little as possible on the python +/// side and keep as much as possible on the rust side. Therefore bindings to other languages can be created with ease +/// and a command line tool can also be created without much need for new features. This will also ensure consistency +/// between other languages and the command line tool. +pub static STATE: LazyLock>>> = LazyLock::new(|| { + Arc::new(Mutex::new(HashMap::new())) +}); + + +/// Generates a unique identifier that can be used to access a loaded machine learning model. +/// +/// # Returns +/// A unique identifier that can be used to access a loaded machine learning model. 
+pub fn generate_unique_id() -> String {
+    // a v4 UUID gives a collision-safe random identifier with no shared counter state
+    let uuid = Uuid::new_v4();
+    uuid.to_string()
+}
diff --git a/modules/c-wrapper/src/utils.rs b/modules/c-wrapper/src/utils.rs
new file mode 100644
index 0000000..693093c
--- /dev/null
+++ b/modules/c-wrapper/src/utils.rs
@@ -0,0 +1,363 @@
+//! Defines macros and C structs for reducing the amount of boilerplate code required for the C API.
+use std::os::raw::{c_char, c_int};
+use std::ffi::CString;
+
+
+/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as an `EmptyReturn`.
+///
+/// # Arguments
+/// * `str_ptr` - The pointer to the string.
+/// * `var_name` - The name of the variable being processed (for error messages).
+#[macro_export]
+macro_rules! process_string_for_empty_return {
+    ($str_ptr:expr, $var_name:expr) => {
+        match $str_ptr.is_null() {
+            true => {
+                return EmptyReturn {
+                    is_error: 1,
+                    error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw()
+                };
+            },
+            false => {
+                let c_str = unsafe { CStr::from_ptr($str_ptr) };
+                match c_str.to_str() {
+                    Ok(s) => s.to_owned(),
+                    Err(_) => {
+                        return EmptyReturn {
+                            is_error: 1,
+                            error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw()
+                        };
+                    }
+                }
+            }
+        }
+    };
+    // NOTE(review): this arm mixes `return None` (null pointer) with `return EmptyReturn`
+    // (invalid UTF-8) — those return types cannot both compile in the same enclosing
+    // function, and no caller in this patch uses the arm. Confirm intent or remove it.
+    ($str_ptr:expr, $var_name:expr, Option) => {
+        match $str_ptr.is_null() {
+            true => {
+                return None;
+            },
+            false => {
+                let c_str = unsafe { CStr::from_ptr($str_ptr) };
+                match c_str.to_str() {
+                    Ok(s) => Some(s.to_owned()),
+                    Err(_) => {
+                        return EmptyReturn {
+                            is_error: 1,
+                            error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw()
+                        };
+                    }
+                }
+            }
+        }
+    }
+}
+
+/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as a `StringReturn`.
+///
+/// # Arguments
+/// * `str_ptr` - The pointer to the string.
+/// * `var_name` - The name of the variable being processed (for error messages).
+#[macro_export]
+macro_rules! process_string_for_string_return {
+    ($str_ptr:expr, $var_name:expr) => {
+        match $str_ptr.is_null() {
+            true => {
+                return StringReturn {
+                    is_error: 1,
+                    error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw(),
+                    string: std::ptr::null_mut()
+                };
+            },
+            false => {
+                let c_str = unsafe { CStr::from_ptr($str_ptr) };
+                match c_str.to_str() {
+                    Ok(s) => s.to_owned(),
+                    Err(_) => {
+                        return StringReturn {
+                            is_error: 1,
+                            error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw(),
+                            string: std::ptr::null_mut()
+                        };
+                    }
+                }
+            }
+        }
+    };
+}
+
+
+/// Checks that the pointer to the string is not null and converts to a Rust string. Any errors are returned as a `VecU8Return`.
+///
+/// # Arguments
+/// * `str_ptr` - The pointer to the string.
+/// * `var_name` - The name of the variable being processed (for error messages).
+#[macro_export]
+macro_rules! process_string_for_vec_u8_return {
+    ($str_ptr:expr, $var_name:expr) => {
+        match $str_ptr.is_null() {
+            true => {
+                return VecU8Return {
+                    data: std::ptr::null_mut(),
+                    length: 0,
+                    capacity: 0,
+                    is_error: 1,
+                    error_message: CString::new(format!("Received a null pointer for {}", $var_name)).unwrap().into_raw()
+                };
+            },
+            false => {
+                let c_str = unsafe { CStr::from_ptr($str_ptr) };
+                match c_str.to_str() {
+                    Ok(s) => s.to_owned(),
+                    Err(_) => {
+                        return VecU8Return {
+                            data: std::ptr::null_mut(),
+                            length: 0,
+                            capacity: 0,
+                            is_error: 1,
+                            error_message: CString::new(format!("Invalid UTF-8 string received for {}", $var_name)).unwrap().into_raw()
+                        };
+                    }
+                }
+            }
+        }
+    };
+}
+
+
+/// Checks the result of an execution and returns an `StringReturn` if an error occurred.
+///
+/// # Arguments
+/// * `execution` - The execution such as a function call to map to `StringReturn` if an error occurred.
+#[macro_export]
+macro_rules! string_return_safe_eject {
+    ($execution:expr) => {
+        match $execution {
+            Ok(s) => s,
+            Err(e) => {
+                return StringReturn {
+                    string: std::ptr::null_mut(),
+                    is_error: 1,
+                    error_message: CString::new(e.to_string()).unwrap().into_raw()
+                }
+            }
+        }
+    };
+}
+
+
+/// Checks the result of an execution and returns an `EmptyReturn` if an error occurred or a none is returned.
+///
+/// # Arguments (Option form)
+/// * `execution` - The execution such as a function call to map to `EmptyReturn` if an error occurred.
+/// * `var` - The variable name to include in the error message.
+/// * `Option` - Marker selecting the `Option`-unwrapping arm.
+///
+/// # Arguments (Result form)
+/// * `execution` - The execution such as a function call to map to `EmptyReturn` if an error occurred.
+#[macro_export]
+macro_rules! empty_return_safe_eject {
+    ($execution:expr, $var:expr, Option) => {
+        match $execution {
+            Some(s) => s,
+            None => {
+                return EmptyReturn {
+                    is_error: 1,
+                    error_message: CString::new($var).unwrap().into_raw()
+                }
+            }
+        }
+    };
+    ($execution:expr) => {
+        match $execution {
+            Ok(s) => s,
+            Err(e) => {
+                return EmptyReturn {
+                    is_error: 1,
+                    error_message: CString::new(e.to_string()).unwrap().into_raw()
+                }
+            }
+        }
+    };
+}
+
+
+/// Returns a simple String to the caller.
+///
+/// # Fields
+/// * `string` - The string to return.
+/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not).
+/// * `error_message` - An optional error message.
+#[repr(C)]
+pub struct StringReturn {
+    pub string: *mut c_char,
+    pub is_error: c_int,
+    pub error_message: *mut c_char
+}
+
+
+impl StringReturn {
+
+    /// Returns a new `StringReturn` object with the string and no error.
+    ///
+    /// # Arguments
+    /// * `string` - The string to return.
+    ///
+    /// # Returns
+    /// A new `StringReturn` object.
+    pub fn success(string: String) -> Self {
+        StringReturn {
+            // ownership of the CString is handed to the caller; it must be released
+            // via `free_string_return`
+            string: CString::new(string).unwrap().into_raw(),
+            is_error: 0,
+            error_message: std::ptr::null_mut()
+        }
+    }
+}
+
+
+/// Frees the memory allocated for the `StringReturn` object.
+///
+/// # Arguments
+/// * `string_return` - The `StringReturn` object to free.
+#[no_mangle]
+pub extern "C" fn free_string_return(string_return: StringReturn) {
+    // Free the string if it is not null
+    if !string_return.string.is_null() {
+        unsafe { drop(CString::from_raw(string_return.string)) };
+    }
+    // Free the error message if it is not null
+    if !string_return.error_message.is_null() {
+        unsafe { drop(CString::from_raw(string_return.error_message)) };
+    }
+}
+
+
+/// Returns a simple empty return object to the caller.
+///
+/// # Fields
+/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not).
+/// * `error_message` - An optional error message.
+#[repr(C)]
+pub struct EmptyReturn {
+    pub is_error: c_int, // 0 for success, 1 for error
+    pub error_message: *mut c_char, // Optional error message
+}
+
+impl EmptyReturn {
+
+    /// Returns a new `EmptyReturn` object with no error.
+    ///
+    /// # Returns
+    /// A new `EmptyReturn` object.
+    pub fn success() -> Self {
+        EmptyReturn {
+            is_error: 0,
+            error_message: std::ptr::null_mut()
+        }
+    }
+}
+
+
+/// Frees the memory allocated for the `EmptyReturn` object.
+///
+/// # Arguments
+/// * `empty_return` - The `EmptyReturn` object to free.
+#[no_mangle]
+pub extern "C" fn free_empty_return(empty_return: EmptyReturn) {
+    // Free the error message if it is not null
+    if !empty_return.error_message.is_null() {
+        unsafe { drop(CString::from_raw(empty_return.error_message)) };
+    }
+}
+
+
+/// Returns a vector of bytes to the caller.
+///
+/// # Fields
+/// * `data` - The pointer to the data.
+/// * `length` - The length of the data.
+/// * `capacity` - The capacity of the data.
+/// * `is_error` - A flag indicating if an error occurred (1 if error 0 if not).
+/// * `error_message` - An optional error message. +#[repr(C)] +pub struct VecU8Return { + pub data: *mut u8, + pub length: usize, + pub capacity: usize, // Optional if you want to include capacity for clarity + pub is_error: c_int, + pub error_message: *mut c_char +} + + +impl VecU8Return { + + /// Returns a new `VecU8Return` object with the data and no error. + /// + /// # Arguments + /// * `data` - The data to return. + /// + /// # Returns + /// A new `VecU8Return` object. + pub fn success(data: Vec) -> Self { + let mut data = data; + let data_ptr = data.as_mut_ptr(); + let length = data.len(); + let capacity = data.capacity(); + std::mem::forget(data); + VecU8Return { + data: data_ptr, + length, + capacity, + is_error: 0, + error_message: std::ptr::null_mut() + } + } +} + + +/// Frees the memory allocated for the `VecU8Return` object. +/// +/// # Arguments +/// * `vec_u8` - The `VecU8Return` object to free. +#[no_mangle] +pub extern "C" fn free_vec_u8(vec_u8: VecU8Return) { + // Free the data if it is not null + if !vec_u8.data.is_null() { + unsafe { drop(Vec::from_raw_parts(vec_u8.data, vec_u8.length, vec_u8.capacity)) }; + } +} + + +/// Holds the data around the outcome of the raw_compute function. +/// +/// # Fields +/// * `data` - The data returned from the computation. +/// * `length` - The length of the data. +/// * `capacity` - The capacity of the data. +/// * `is_error` - A flag indicating if an error occurred (1 for error, 0 for success). +/// * `error_message` - An error message if the computation failed. +#[repr(C)] +pub struct Vecf32Return { + pub data: *mut f32, + pub length: usize, + pub capacity: usize, // Optional if you want to include capacity for clarity + pub is_error: c_int, + pub error_message: *mut c_char +} + + +/// Frees the memory allocated for the Vecf32Return. +/// +/// # Arguments +/// * `vecf32_return` - The Vecf32Return to free. 
+#[no_mangle]
+pub extern "C" fn free_vecf32_return(vecf32_return: Vecf32Return) {
+    // Free the data if it is not null
+    if !vecf32_return.data.is_null() {
+        unsafe { drop(Vec::from_raw_parts(vecf32_return.data, vecf32_return.length, vecf32_return.capacity)) };
+    }
+    // Free the error message if it is not null
+    if !vecf32_return.error_message.is_null() {
+        unsafe { drop(CString::from_raw(vecf32_return.error_message)) };
+    }
+}
\ No newline at end of file
diff --git a/modules/c-wrapper/tests/__init__.py b/modules/c-wrapper/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/c-wrapper/tests/api/__init__.py b/modules/c-wrapper/tests/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/c-wrapper/tests/api/execution/__init__.py b/modules/c-wrapper/tests/api/execution/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/c-wrapper/tests/api/execution/test_buffered_compute.py b/modules/c-wrapper/tests/api/execution/test_buffered_compute.py
new file mode 100644
index 0000000..c940747
--- /dev/null
+++ b/modules/c-wrapper/tests/api/execution/test_buffered_compute.py
@@ -0,0 +1,80 @@
+import ctypes
+from unittest import TestCase, main
+
+from test_utils.c_lib_loader import load_library
+from test_utils.return_structs import FileInfo, Vecf32Return
+from test_utils.routes import TEST_SURML_PATH
+
+
+class TestExecution(TestCase):
+    """Exercises the `buffered_compute` C entry point end to end via ctypes."""
+
+    def setUp(self) -> None:
+        self.lib = load_library()
+
+        # Define the Rust function signatures
+        self.lib.load_model.argtypes = [ctypes.c_char_p]
+        self.lib.load_model.restype = FileInfo
+
+        self.lib.free_file_info.argtypes = [FileInfo]
+
+        self.lib.buffered_compute.argtypes = [
+            ctypes.c_char_p,                  # file_id_ptr -> *const c_char
+            ctypes.POINTER(ctypes.c_float),   # data_ptr -> *const c_float
+            ctypes.c_size_t,                  # data_length -> usize
+            ctypes.POINTER(ctypes.c_char_p),  # strings -> *const *const c_char
+            ctypes.c_int                      # string_count -> c_int
+        ]
+        self.lib.buffered_compute.restype = Vecf32Return
+
+        self.lib.free_vecf32_return.argtypes = [Vecf32Return]
+
+    def test_buffered_compute(self):
+        # Load a test model
+        c_string = str(TEST_SURML_PATH).encode('utf-8')
+        file_info = self.lib.load_model(c_string)
+
+        if file_info.error_message:
+            self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}")
+
+        input_data = {
+            "squarefoot": 500.0,
+            "num_floors": 2.0
+        }
+
+        string_buffer = []
+        data_buffer = []
+        for key, value in input_data.items():
+            string_buffer.append(key.encode('utf-8'))
+            data_buffer.append(value)
+
+        # Prepare input data as a ctypes array
+        # (note: `input_data` is rebound from the dict to the float array here)
+        array_type = ctypes.c_float * len(data_buffer)  # Create an array type of the appropriate size
+        input_data = array_type(*data_buffer)  # Instantiate the array with the list elements
+
+        # prepare the input strings
+        string_array = (ctypes.c_char_p * len(string_buffer))(*string_buffer)
+        string_count = len(string_buffer)
+
+        # Call the raw_compute function
+        result = self.lib.buffered_compute(
+            file_info.file_id,
+            input_data,
+            len(input_data),
+            string_array,
+            string_count
+        )
+
+        if result.is_error:
+            self.fail(f"Error in buffered_compute: {result.error_message.decode('utf-8')}")
+
+        # Extract and verify the computation result
+        outcome = [result.data[i] for i in range(result.length)]
+        self.assertEqual(362.9851989746094, outcome[0])
+
+        # Free allocated memory
+        self.lib.free_vecf32_return(result)
+        self.lib.free_file_info(file_info)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/c-wrapper/tests/api/execution/test_raw_compute.py b/modules/c-wrapper/tests/api/execution/test_raw_compute.py
new file mode 100644
index 0000000..30dba87
--- /dev/null
+++ b/modules/c-wrapper/tests/api/execution/test_raw_compute.py
@@ -0,0 +1,54 @@
+import ctypes
+from unittest import TestCase, main
+
+from test_utils.c_lib_loader import load_library
+from test_utils.return_structs import FileInfo, Vecf32Return
+from test_utils.routes import TEST_SURML_PATH
+
+
+class TestExecution(TestCase):
+    """Exercises the `raw_compute` C entry point end to end via ctypes."""
+
+    def setUp(self) -> None:
+        self.lib = load_library()
+
+        # Define the Rust function signatures
+        self.lib.load_model.argtypes = [ctypes.c_char_p]
+        self.lib.load_model.restype = FileInfo
+
+        self.lib.free_file_info.argtypes = [FileInfo]
+
+        self.lib.raw_compute.argtypes = [ctypes.c_char_p, ctypes.POINTER(ctypes.c_float), ctypes.c_size_t]
+        self.lib.raw_compute.restype = Vecf32Return
+
+        self.lib.free_vecf32_return.argtypes = [Vecf32Return]
+
+    def test_raw_compute(self):
+        # Load a test model
+        c_string = str(TEST_SURML_PATH).encode('utf-8')
+        file_info = self.lib.load_model(c_string)
+
+        if file_info.error_message:
+            self.fail(f"Failed to load model: {file_info.error_message.decode('utf-8')}")
+
+        # Prepare input data as a ctypes array
+        data_buffer = [1.0, 4.0]
+        array_type = ctypes.c_float * len(data_buffer)  # Create an array type of the appropriate size
+        input_data = array_type(*data_buffer)  # Instantiate the array with the list elements
+
+        # Call the raw_compute function
+        result = self.lib.raw_compute(file_info.file_id, input_data, len(input_data))
+
+        if result.is_error:
+            self.fail(f"Error in raw_compute: {result.error_message.decode('utf-8')}")
+
+        # Extract and verify the computation result
+        outcome = [result.data[i] for i in range(result.length)]
+        self.assertEqual(1.8246129751205444, outcome[0])
+
+        # Free allocated memory
+        self.lib.free_vecf32_return(result)
+        self.lib.free_file_info(file_info)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/c-wrapper/tests/api/storage/__init__.py b/modules/c-wrapper/tests/api/storage/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/c-wrapper/tests/api/storage/test_load_cached_raw_model.py b/modules/c-wrapper/tests/api/storage/test_load_cached_raw_model.py
new file mode 100644
index 0000000..8ce31d7
--- /dev/null
+++ b/modules/c-wrapper/tests/api/storage/test_load_cached_raw_model.py
@@ -0,0 +1,46 @@
+import ctypes
+from unittest import TestCase, main
+
+from test_utils.c_lib_loader import load_library
+from test_utils.return_structs import StringReturn
+from test_utils.routes import SHOULD_BREAK_FILE, TEST_ONNX_FILE_PATH
+
+
+class TestLoadCachedRawModel(TestCase):
+    """Covers null-pointer, bad-path, bad-format, and happy-path behaviour of
+    the `load_cached_raw_model` C entry point."""
+
+    def setUp(self) -> None:
+        self.lib = load_library()
+        # define the types
+        self.lib.load_cached_raw_model.restype = StringReturn
+        self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p]
+
+    def test_null_pointer_protection(self):
+        null_pointer = None
+        outcome: StringReturn = self.lib.load_cached_raw_model(null_pointer)
+        self.assertEqual(1, outcome.is_error)
+        self.assertEqual("Received a null pointer for file path", outcome.error_message.decode('utf-8'))
+
+    def test_wrong_path(self):
+        wrong_path = "should_break".encode('utf-8')
+        outcome: StringReturn = self.lib.load_cached_raw_model(wrong_path)
+        self.assertEqual(1, outcome.is_error)
+        self.assertEqual(
+            "No such file or directory (os error 2)",
+            outcome.error_message.decode('utf-8')
+        )
+
+    def test_wrong_file_format(self):
+        wrong_file_type = str(SHOULD_BREAK_FILE).encode('utf-8')
+        outcome: StringReturn = self.lib.load_cached_raw_model(wrong_file_type)
+        # below is unexpected and also happens in the old API
+        # TODO => throw an error if the file format is incorrect
+        self.assertEqual(0, outcome.is_error)
+
+    def test_success(self):
+        right_file = str(TEST_ONNX_FILE_PATH).encode('utf-8')
+        outcome: StringReturn = self.lib.load_cached_raw_model(right_file)
+        self.assertEqual(0, outcome.is_error)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/c-wrapper/tests/api/storage/test_load_model.py b/modules/c-wrapper/tests/api/storage/test_load_model.py
new file mode 100644
index 0000000..0c31503
--- /dev/null
+++ b/modules/c-wrapper/tests/api/storage/test_load_model.py
@@ -0,0 +1,39 @@
+import ctypes
+from unittest import TestCase, main
+
+from test_utils.c_lib_loader import load_library
+from test_utils.return_structs import FileInfo
+from test_utils.routes import SHOULD_BREAK_FILE, TEST_SURML_PATH
+
+
+class TestLoadModel(TestCase):
+    """Covers null-pointer, malformed-file, and happy-path behaviour of the
+    `load_model` C entry point."""
+
+    def setUp(self) -> None:
+        self.lib = load_library()
+        self.lib.load_model.restype = FileInfo
+        self.lib.load_model.argtypes = [ctypes.c_char_p]
+        self.lib.free_file_info.argtypes = [FileInfo]
+
+    def test_null_pointer_protection(self):
+        null_pointer = None
+        outcome: FileInfo = self.lib.load_model(null_pointer)
+        self.assertEqual(1, outcome.is_error)
+        self.assertEqual("Received a null pointer for file path", outcome.error_message.decode('utf-8'))
+
+    def test_wrong_file(self):
+        wrong_file_type = str(SHOULD_BREAK_FILE).encode('utf-8')
+        outcome: FileInfo = self.lib.load_model(wrong_file_type)
+        self.assertEqual(1, outcome.is_error)
+        self.assertEqual(True, "failed to fill whole buffer" in outcome.error_message.decode('utf-8'))
+
+    def test_success(self):
+        surml_file_path = str(TEST_SURML_PATH).encode('utf-8')
+        outcome: FileInfo = self.lib.load_model(surml_file_path)
+        self.assertEqual(0, outcome.is_error)
+        self.lib.free_file_info(outcome)
+
+
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/c-wrapper/tests/api/storage/test_meta.py b/modules/c-wrapper/tests/api/storage/test_meta.py
new file mode 100644
index 0000000..c4abf91
--- /dev/null
+++ b/modules/c-wrapper/tests/api/storage/test_meta.py
@@ -0,0 +1,142 @@
+"""
+Tests all the meta data functions
+"""
+import ctypes
+from unittest import TestCase, main
+from typing import Optional
+import os
+
+from test_utils.c_lib_loader import load_library
+from test_utils.return_structs import EmptyReturn, FileInfo, StringReturn
+from test_utils.routes import TEST_SURML_PATH, TEST_ONNX_FILE_PATH, ASSETS_PATH
+
+
+class TestMeta(TestCase):
+
+    def setUp(self) -> None:
+        self.lib = load_library()
+        self.lib.add_name.restype = EmptyReturn
+
+        # Define the signatues of the basic meta functions
+        self.functions = [
+            self.lib.add_name,
+            self.lib.add_description,
+            self.lib.add_version,
+            self.lib.add_column,
+            self.lib.add_author,
+            self.lib.add_origin,
+            self.lib.add_engine,
+        ]
+        for i in self.functions:
+            i.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
+            i.restype = EmptyReturn
+
+        # Define the load model signature
+        self.lib.load_model.restype = FileInfo
+        self.lib.load_model.argtypes = [ctypes.c_char_p]
+        self.lib.free_file_info.argtypes = [FileInfo]
+        # define the load raw model signature
+        self.lib.load_cached_raw_model.restype = StringReturn
+        self.lib.load_cached_raw_model.argtypes = [ctypes.c_char_p]
+        # define the save model signature
+        self.lib.save_model.restype = EmptyReturn
+        self.lib.save_model.argtypes = [ctypes.c_char_p, ctypes.c_char_p]
+        # load the model for tests
+        self.model: FileInfo = self.lib.load_model(str(TEST_SURML_PATH).encode('utf-8'))
+        self.file_id = self.model.file_id.decode('utf-8')
+        # set by test_add_metadata_and_save so tearDown can remove the temp file
+        self.temp_test_id: Optional[str] = None
+
+    def tearDown(self) -> None:
+        self.lib.free_file_info(self.model)
+
+        # remove the temp surml file created in assets if present
+        if self.test_temp_surml_file_path is not None:
+            os.remove(self.test_temp_surml_file_path)
+
+    def test_null_protection(self):
+        placeholder = "placeholder".encode('utf-8')
+        file_id = self.file_id.encode('utf-8')
+
+        # check that they all protect against file ID null pointers
+        for i in self.functions:
+            outcome: EmptyReturn = i(None, placeholder)
+            self.assertEqual(1, outcome.is_error)
+            self.assertEqual(
+                "Received a null pointer for file id",
+                outcome.error_message.decode('utf-8')
+            )
+
+        # check that they all protect against null pointers for the field type
+        # (order must match self.functions)
+        outcomes = [
+            "model name",
+            "description",
+            "version",
+            "column name",
+            "author",
+            "origin",
+            "engine",
+        ]
+        counter = 0
+        for i in self.functions:
+            outcome: EmptyReturn = i(file_id, None)
+            self.assertEqual(1, outcome.is_error)
+            self.assertEqual(
+                f"Received a null pointer for {outcomes[counter]}",
+                outcome.error_message.decode('utf-8')
+            )
+            counter += 1
+
+    def test_model_not_found(self):
+        placeholder = "placeholder".encode('utf-8')
+
+        # check they all return errors if not found
+        for i in self.functions:
+            outcome: EmptyReturn = i(placeholder, placeholder)
+            self.assertEqual(1, outcome.is_error)
+            self.assertEqual("Model not found", outcome.error_message.decode('utf-8'))
+
+    def test_add_metadata_and_save(self):
+        file_id: StringReturn = self.lib.load_cached_raw_model(str(TEST_SURML_PATH).encode('utf-8'))
+        self.assertEqual(0, file_id.is_error)
+
+        decoded_file_id = file_id.string.decode('utf-8')
+        self.temp_test_id = decoded_file_id
+
+        self.assertEqual(
+            0,
+            self.lib.add_name(file_id.string, "test name".encode('utf-8')).is_error
+        )
+        self.assertEqual(
+            0,
+            self.lib.add_description(file_id.string, "test description".encode('utf-8')).is_error
+        )
+        self.assertEqual(
+            0,
+            self.lib.add_version(file_id.string, "0.0.1".encode('utf-8')).is_error
+        )
+        self.assertEqual(
+            0,
+            self.lib.add_author(file_id.string, "test author".encode('utf-8')).is_error
+        )
+        self.assertEqual(
+            0,
+            self.lib.save_model(self.test_temp_surml_file_path.encode("utf-8"), file_id.string).is_error
+        )
+
+        outcome: FileInfo = self.lib.load_model(self.test_temp_surml_file_path.encode('utf-8'))
+        self.assertEqual(0, outcome.is_error)
+        self.assertEqual("test name", outcome.name.decode('utf-8'))
+        self.assertEqual("test description", outcome.description.decode('utf-8'))
+        self.assertEqual("0.0.1", outcome.version.decode('utf-8'))
+
+
+    # NOTE(review): a property whose name starts with `test_` is confusing next to
+    # unittest's test-method naming convention — consider renaming to
+    # `temp_surml_file_path` (unittest only collects callables, so it is not run as a test)
+    @property
+    def test_temp_surml_file_path(self) -> Optional[str]:
+        if self.temp_test_id is None:
+            return None
+        return str(ASSETS_PATH.joinpath(f"{self.temp_test_id}.surml"))
+
+
+
+if __name__ == '__main__':
+    main()
diff --git a/modules/c-wrapper/tests/test_utils/__init__.py b/modules/c-wrapper/tests/test_utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/modules/c-wrapper/tests/test_utils/assets/linear_test.onnx b/modules/c-wrapper/tests/test_utils/assets/linear_test.onnx
new file mode 
100644 index 0000000..f7b070b --- /dev/null +++ b/modules/c-wrapper/tests/test_utils/assets/linear_test.onnx @@ -0,0 +1,15 @@ +pytorch2.0.1:… +Q +onnx::MatMul_0 +onnx::MatMul_6/linear/MatMul_output_0/linear/MatMul"MatMul +; + linear.bias +/linear/MatMul_output_05 /linear/Add"Add torch_jit*B linear.biasJdÕ ²* Bonnx::MatMul_6J ‘9?[ÄŒ>Z +onnx::MatMul_0 + + +b +5 + + +B \ No newline at end of file diff --git a/modules/c-wrapper/tests/test_utils/assets/should_break.txt b/modules/c-wrapper/tests/test_utils/assets/should_break.txt new file mode 100644 index 0000000..e69de29 diff --git a/modules/c-wrapper/tests/test_utils/assets/test.surml b/modules/c-wrapper/tests/test_utils/assets/test.surml new file mode 100644 index 0000000..61da29a Binary files /dev/null and b/modules/c-wrapper/tests/test_utils/assets/test.surml differ diff --git a/modules/c-wrapper/tests/test_utils/c_lib_loader.py b/modules/c-wrapper/tests/test_utils/c_lib_loader.py new file mode 100644 index 0000000..00adfc6 --- /dev/null +++ b/modules/c-wrapper/tests/test_utils/c_lib_loader.py @@ -0,0 +1,56 @@ +import ctypes +import platform +from pathlib import Path +import os +from test_utils.return_structs import EmptyReturn + + +def load_library(lib_name: str = "libc_wrapper") -> ctypes.CDLL: + """ + Load the correct shared library based on the operating system. + + Args: + lib_name (str): The base name of the library without extension (e.g., "libc_wrapper"). + + Returns: + ctypes.CDLL: The loaded shared library. 
+ """ + current_dir = Path(__file__).parent + system_name = platform.system() + + # os.environ["ORT_LIB_LOCATION"] = str(current_dir.joinpath("onnxruntime.dll")) + + if system_name == "Windows": + lib_path = current_dir.joinpath(f"{lib_name}.dll") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.dll") + elif system_name == "Darwin": # macOS + lib_path = current_dir.joinpath(f"{lib_name}.dylib") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.dylib") + elif system_name == "Linux": + lib_path = current_dir.joinpath(f"{lib_name}.so") + onnx_path = current_dir.joinpath("onnxruntime").joinpath("lib").joinpath("onnxruntime.so.1") + else: + raise OSError(f"Unsupported operating system: {system_name}") + + + # onnx_lib_path = current_dir.joinpath("onnxruntime").joinpath("lib") + # current_ld_library_path = os.environ.get("LD_LIBRARY_PATH", "") + # # Update LD_LIBRARY_PATH + # os.environ["LD_LIBRARY_PATH"] = f"{onnx_lib_path}:{current_ld_library_path}" + # os.environ["ORT_LIB_LOCATION"] = str(onnx_lib_path) + + # ctypes.CDLL(str(onnx_path), mode=ctypes.RTLD_GLOBAL) + onnx_path = current_dir.joinpath("onnxruntime") + + if not lib_path.exists(): + raise FileNotFoundError(f"Shared library not found at: {lib_path}") + + loaded_lib = ctypes.CDLL(str(lib_path)) + loaded_lib.link_onnx.argtypes = [ctypes.c_char_p] + loaded_lib.link_onnx.restype = EmptyReturn + c_string = str(onnx_path).encode('utf-8') + load_info = loaded_lib.link_onnx(c_string) + if load_info.error_message: + raise OSError(f"Failed to load onnxruntime: {load_info.error_message.decode('utf-8')}") + + return ctypes.CDLL(str(lib_path)) diff --git a/modules/c-wrapper/tests/test_utils/return_structs.py b/modules/c-wrapper/tests/test_utils/return_structs.py new file mode 100644 index 0000000..3baaa35 --- /dev/null +++ b/modules/c-wrapper/tests/test_utils/return_structs.py @@ -0,0 +1,64 @@ +""" +Defines all the C structs that are used in the 
tests. +""" +from ctypes import Structure, c_char_p, c_int, c_size_t, POINTER, c_float + + +class StringReturn(Structure): + """ + A return type that just returns a string + + Fields: + string: the string that is being returned (only present if successful) + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("string", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + +class EmptyReturn(Structure): + """ + A return type that just returns nothing + + Fields: + is_error: 1 if error, 0 if not + error_message: the error message (only present if error) + """ + _fields_ = [ + ("is_error", c_int), # Corresponds to c_int + ("error_message", c_char_p) # Corresponds to *mut c_char + ] + + +class FileInfo(Structure): + """ + A return type when loading the meta of a surml file. + + Fields: + file_id: a unique identifier for the file in the state of the C lib + name: a name of the model + description: a description of the model + error_message: the error message (only present if error) + is_error: 1 if error, 0 if not + """ + _fields_ = [ + ("file_id", c_char_p), # Corresponds to *mut c_char + ("name", c_char_p), # Corresponds to *mut c_char + ("description", c_char_p), # Corresponds to *mut c_char + ("version", c_char_p), # Corresponds to *mut c_char + ("error_message", c_char_p), # Corresponds to *mut c_char + ("is_error", c_int) # Corresponds to c_int + ] + + +class Vecf32Return(Structure): + _fields_ = [ + ("data", POINTER(c_float)), # Pointer to f32 array + ("length", c_size_t), # Length of the array + ("capacity", c_size_t), # Capacity of the array + ("is_error", c_int), # Indicates if it's an error + ("error_message", c_char_p), # Optional error message + ] diff --git a/modules/c-wrapper/tests/test_utils/routes.py b/modules/c-wrapper/tests/test_utils/routes.py new file mode 100644 index 0000000..07e07e7 --- 
/dev/null +++ b/modules/c-wrapper/tests/test_utils/routes.py @@ -0,0 +1,12 @@ +""" +Defines all the routes for the testing module to all the assets and C libs +""" +from pathlib import Path + + +UTILS_PATH = Path(__file__).parent +ASSETS_PATH = UTILS_PATH.joinpath("assets") +TEST_SURML_PATH = ASSETS_PATH.joinpath("test.surml") +SHOULD_BREAK_FILE = ASSETS_PATH.joinpath("should_break.txt") +TEST_ONNX_FILE_PATH = ASSETS_PATH.joinpath("linear_test.onnx") +ONNX_LIB = UTILS_PATH.joinpath("..").joinpath("..").joinpath("onnx_lib").joinpath("onnxruntime") diff --git a/modules/core/Cargo.toml b/modules/core/Cargo.toml index 82cb07e..8d2dac1 100644 --- a/modules/core/Cargo.toml +++ b/modules/core/Cargo.toml @@ -17,16 +17,18 @@ sklearn-tests = [] onnx-tests = [] torch-tests = [] tensorflow-tests = [] +gpu = [] +dynamic = ["ort/load-dynamic"] [dependencies] regex = "1.9.3" -ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false } -ndarray = "0.15.6" +ort = { version = "2.0.0-rc.9", features = [ "cuda", "ndarray" ]} +ndarray = "0.16.1" once_cell = "1.18.0" bytes = "1.5.0" futures-util = "0.3.28" futures-core = "0.3.28" -thiserror = "1.0.57" +thiserror = "2.0.9" serde = { version = "1.0.197", features = ["derive"] } axum = { version = "0.7.4", optional = true } actix-web = { version = "4.5.1", optional = true } @@ -39,5 +41,5 @@ tokio = { version = "1.12.0", features = ["full"] } name = "surrealml_core" path = "src/lib.rs" -[build-dependencies] -ort = { version = "1.16.2", default-features = true } +# [build-dependencies] +# ort = { version = "1.16.2", default-features = true } diff --git a/modules/core/Dockerfile b/modules/core/Dockerfile new file mode 100644 index 0000000..c4f15ef --- /dev/null +++ b/modules/core/Dockerfile @@ -0,0 +1,37 @@ +# Use an official Rust image +FROM rust:1.83-slim + +# Install necessary tools +RUN apt-get update && apt-get install -y \ + wget \ + build-essential \ + libssl-dev \ + pkg-config \ + && rm -rf 
/var/lib/apt/lists/* + +# Set the working directory +WORKDIR /app + +# Copy the project files into the container +COPY . . + +# Download ONNX Runtime 1.20.0 +RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.20.0/onnxruntime-linux-x64-1.20.0.tgz \ + && tar -xvf onnxruntime-linux-x64-1.20.0.tgz \ + && mv onnxruntime-linux-x64-1.20.0 /onnxruntime + +# # Download ONNX Runtime 1.16.0 +# RUN wget https://github.com/microsoft/onnxruntime/releases/download/v1.16.0/onnxruntime-linux-x64-1.16.0.tgz \ +# && tar -xvf onnxruntime-linux-x64-1.16.0.tgz \ +# && mv onnxruntime-linux-x64-1.16.0 /onnxruntime + +# Set the ONNX Runtime library path +ENV ORT_LIB_LOCATION=/onnxruntime/lib +ENV LD_LIBRARY_PATH=$ORT_LIB_LOCATION:$LD_LIBRARY_PATH + +# Clean and build the Rust project +RUN cargo clean +RUN cargo build --features tensorflow-tests + +# Run the tests +CMD ["cargo", "test", "--features", "tensorflow-tests"] diff --git a/modules/core/build.rs b/modules/core/build.rs index 53670b5..a1fd6aa 100644 --- a/modules/core/build.rs +++ b/modules/core/build.rs @@ -1,109 +1,4 @@ -use std::env; -use std::fs; -use std::path::Path; -/// works out where the `onnxruntime` library is in the build target and copies the library to the root -/// of the crate so the core library can find it and load it into the binary using `include_bytes!()`. -/// -/// # Notes -/// This is a workaround for the fact that `onnxruntime` doesn't support `cargo` yet. This build step -/// is reliant on the `ort` crate downloading and building the `onnxruntime` library. This is -/// why the following dependency is required in `Cargo.toml`: -/// ```toml -/// [build-dependencies] -/// ort = { version = "1.16.2", default-features = true } -/// ``` -/// Here we can see that the `default-features` is set to `true`. This is because the `ort` crate will download -/// the correct package and build it for the target platform by default. 
In the main part of our dependencies -/// we have the following: -/// ```toml -/// [dependencies] -/// ort = { version = "1.16.2", features = ["load-dynamic"], default-features = false } -/// ``` -/// Here we can see that the `default-features` is set to `false`. This is because we don't want the `ort` crate -/// to download and build the `onnxruntime` library again. Instead we want to use the one that was built in the -/// build step. We also set the `load-dynamic` feature to `true` so that the `ort` crate will load the `onnxruntime` -/// library dynamically at runtime. This is because we don't want to statically link the `onnxruntime`. Our `onnxruntime` -/// is embedded into the binary using `include_bytes!()` and we want to load it dynamically at runtime. This means that -/// we do not need to move the `onnxruntime` library around with the binary, and there is no complicated setup required -/// or linking. -fn unpack_onnx() -> std::io::Result<()> { - let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); - let out_path = Path::new(&out_dir); - let build_dir = out_path - .ancestors() // This gives an iterator over all ancestors of the path - .nth(3) // 'nth(3)' gets the fourth ancestor (counting from 0), which should be the debug directory - .expect("Failed to find debug directory"); - - match std::env::var("ONNXRUNTIME_LIB_PATH") { - Ok(onnx_path) => { - println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH set at: {}", onnx_path); - println!("cargo:rustc-cfg=onnx_runtime_env_var_set"); - } - Err(_) => { - println!("Surrealml Core Debug: ONNXRUNTIME_LIB_PATH not set"); - let target_lib = match env::var("CARGO_CFG_TARGET_OS").unwrap() { - ref s if s.contains("linux") => "libonnxruntime.so", - ref s if s.contains("macos") => "libonnxruntime.dylib", - ref s if s.contains("windows") => "onnxruntime.dll", - // ref s if s.contains("android") => "android", => not building for android - _ => panic!("Unsupported target os"), - }; - - let lib_path = 
build_dir.join(target_lib); - let lib_path = lib_path.to_str().unwrap(); - println!("Surrealml Core Debug: lib_path={}", lib_path); - - // Check if the path exists - if fs::metadata(lib_path).is_ok() { - println!("Surrealml Core Debug: lib_path exists"); - } else { - println!("Surrealml Core Debug: lib_path does not exist"); - // Extract the directory path - if let Some(parent) = std::path::Path::new(lib_path).parent() { - // Print the contents of the directory - match fs::read_dir(parent) { - Ok(entries) => { - println!("Surrealml Core Debug: content of directory {}", parent.display()); - for entry in entries { - if let Ok(entry) = entry { - println!("{}", entry.path().display()); - } - } - } - Err(e) => { - println!("Surrealml Core Debug: Failed to read directory {}: {}", parent.display(), e); - } - } - } else { - println!("Surrealml Core Debug: Could not determine the parent directory of the path."); - } - } - - // put it next to the file of the embedding - let destination = Path::new(target_lib); - fs::copy(lib_path, destination)?; - println!("Surrealml Core Debug: onnx lib copied from {} to {}", lib_path, destination.display()); - } - } - Ok(()) +fn main() { + } - -fn main() -> std::io::Result<()> { - if std::env::var("DOCS_RS").is_ok() { - // we are not going to be anything here for docs.rs, because we are merely building the docs. When we are just building - // the docs, the onnx environment variable will not look for the `onnxruntime` library, so we don't need to unpack it. - return Ok(()); - } - - if env::var("ORT_STRATEGY").as_deref() == Ok("system") { - // If the ORT crate is built with the `system` strategy, then the crate will take care of statically linking the library. - // No need to do anything here. 
- println!("cargo:rustc-cfg=onnx_statically_linked"); - - return Ok(()); - } - - unpack_onnx()?; - Ok(()) -} \ No newline at end of file diff --git a/modules/core/onnxruntime-linux-x64-1.20.0.tgz b/modules/core/onnxruntime-linux-x64-1.20.0.tgz new file mode 100644 index 0000000..2cacda3 Binary files /dev/null and b/modules/core/onnxruntime-linux-x64-1.20.0.tgz differ diff --git a/modules/core/scripts/install_onnxruntime_linux.sh b/modules/core/scripts/install_onnxruntime_linux.sh new file mode 100644 index 0000000..c76605c --- /dev/null +++ b/modules/core/scripts/install_onnxruntime_linux.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +# Variables +ONNX_VERSION="1.20.0" +ONNX_DOWNLOAD_URL="https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz" +ONNX_RUNTIME_DIR="/home/maxwellflitton/Documents/github/surreal/surrealml/modules/core/target/debug/build/ort-680c63907dcb00d8/out/onnxruntime" +ONNX_TARGET_DIR="${ONNX_RUNTIME_DIR}/onnxruntime-linux-x64-${ONNX_VERSION}" +LD_LIBRARY_PATH_UPDATE="${ONNX_TARGET_DIR}/lib" + +# Step 1: Download and Extract ONNX Runtime +echo "Downloading ONNX Runtime version ${ONNX_VERSION}..." +wget -q --show-progress "${ONNX_DOWNLOAD_URL}" -O "onnxruntime-linux-x64-${ONNX_VERSION}.tgz" + +if [ $? -ne 0 ]; then + echo "Failed to download ONNX Runtime. Exiting." + exit 1 +fi + +echo "Extracting ONNX Runtime..." +tar -xvf "onnxruntime-linux-x64-${ONNX_VERSION}.tgz" + +if [ ! -d "onnxruntime-linux-x64-${ONNX_VERSION}" ]; then + echo "Extraction failed. Directory not found. Exiting." + exit 1 +fi + +# Step 2: Replace Old ONNX Runtime +echo "Replacing old ONNX Runtime..." +mkdir -p "${ONNX_RUNTIME_DIR}" +mv "onnxruntime-linux-x64-${ONNX_VERSION}" "${ONNX_TARGET_DIR}" + +if [ ! -d "${ONNX_TARGET_DIR}" ]; then + echo "Failed to move ONNX Runtime to target directory. Exiting." + exit 1 +fi + +# Step 3: Update LD_LIBRARY_PATH +echo "Updating LD_LIBRARY_PATH..." 
+export LD_LIBRARY_PATH="${LD_LIBRARY_PATH_UPDATE}:$LD_LIBRARY_PATH" + +# Step 4: Verify Library Version +echo "Verifying ONNX Runtime version..." +strings "${LD_LIBRARY_PATH_UPDATE}/libonnxruntime.so" | grep "VERS_${ONNX_VERSION}" > /dev/null + +if [ $? -ne 0 ]; then + echo "ONNX Runtime version ${ONNX_VERSION} not found in library. Exiting." + exit 1 +fi + +# Step 5: Install Library Globally (Optional) +echo "Installing ONNX Runtime globally..." +sudo cp "${LD_LIBRARY_PATH_UPDATE}/libonnxruntime.so" /usr/local/lib/ +sudo ldconfig + +if [ $? -ne 0 ]; then + echo "Failed to install ONNX Runtime globally. Exiting." + exit 1 +fi + +# Step 6: Clean and Rebuild Project +echo "Cleaning and rebuilding project..." +cargo clean +cargo test --features tensorflow-tests + +if [ $? -eq 0 ]; then + echo "ONNX Runtime updated successfully, and tests passed." +else + echo "ONNX Runtime updated, but tests failed. Check the logs for details." +fi diff --git a/modules/core/src/execution/compute.rs b/modules/core/src/execution/compute.rs index 6c796a3..584575e 100644 --- a/modules/core/src/execution/compute.rs +++ b/modules/core/src/execution/compute.rs @@ -1,12 +1,13 @@ //! Defines the operations around performing computations on a loaded model. use crate::storage::surml_file::SurMlFile; use std::collections::HashMap; -use ndarray::{ArrayD, CowArray}; -use ort::{SessionBuilder, Value, session::Input}; +use ndarray::ArrayD; +use ort::value::ValueType; +use ort::session::Session; -use super::onnx_environment::ENVIRONMENT; use crate::safe_eject; use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use crate::execution::session::get_session; /// A wrapper for the loaded machine learning model so we can perform computations on the loaded model. @@ -39,15 +40,21 @@ impl <'a>ModelComputation<'a> { /// /// # Returns /// A vector of dimensions for the input tensor to be reshaped into from the loaded model. 
- fn process_input_dims(input_dims: &Input) -> Vec { - let mut buffer = Vec::new(); - for dim in input_dims.dimensions() { - match dim { - Some(dim) => buffer.push(dim as usize), - None => buffer.push(1) + fn process_input_dims(session_ref: &Session) -> Vec { + let some_dims = match &session_ref.inputs[0].input_type { + ValueType::Tensor{ ty: _, dimensions: new_dims, dimension_symbols: _ } => Some(new_dims), + _ => None + }; + let mut dims_cache = Vec::new(); + for dim in some_dims.unwrap() { + if dim < &0 { + dims_cache.push((dim * -1) as usize); + } + else { + dims_cache.push(*dim as usize); } } - buffer + dims_cache } /// Creates a Vector that can be used manipulated with other operations such as normalisation from a hashmap of keys and values. @@ -79,26 +86,33 @@ impl <'a>ModelComputation<'a> { /// # Returns /// The computed output tensor from the loaded model. pub fn raw_compute(&self, tensor: ArrayD, _dims: Option<(i32, i32)>) -> Result, SurrealError> { - let session = safe_eject!(SessionBuilder::new(&ENVIRONMENT), SurrealErrorStatus::Unknown); - let session = safe_eject!(session.with_model_from_memory(&self.surml_file.model), SurrealErrorStatus::Unknown); - let unwrapped_dims = ModelComputation::process_input_dims(&session.inputs[0]); - let tensor = safe_eject!(tensor.into_shape(unwrapped_dims), SurrealErrorStatus::Unknown); - - let x = CowArray::from(tensor).into_dyn(); - let input_values = safe_eject!(Value::from_array(session.allocator(), &x), SurrealErrorStatus::Unknown); - let outputs = safe_eject!(session.run(vec![input_values]), SurrealErrorStatus::Unknown); + let session = get_session(self.surml_file.model.clone())?; + let dims_cache = ModelComputation::process_input_dims(&session); + let tensor = match tensor.into_shape_with_order(dims_cache) { + Ok(tensor) => tensor, + Err(_) => return Err(SurrealError::new("Failed to reshape tensor to input dimensions".to_string(), SurrealErrorStatus::Unknown)) + }; + let tensor = match 
ort::value::Tensor::from_array(tensor) { + Ok(tensor) => tensor, + Err(_) => return Err(SurrealError::new("Failed to convert tensor to ort tensor".to_string(), SurrealErrorStatus::Unknown)) + }; + let x = match ort::inputs![tensor] { + Ok(x) => x, + Err(_) => return Err(SurrealError::new("Failed to create input tensor".to_string(), SurrealErrorStatus::Unknown)) + }; + let outputs = safe_eject!(session.run(x), SurrealErrorStatus::Unknown); let mut buffer: Vec = Vec::new(); // extract the output tensor converting the values to f32 if they are i64 - match outputs[0].try_extract::() { + match outputs[0].try_extract_tensor::() { Ok(y) => { for i in y.view().clone().into_iter() { buffer.push(*i); } }, Err(_) => { - for i in safe_eject!(outputs[0].try_extract::(), SurrealErrorStatus::Unknown).view().clone().into_iter() { + for i in safe_eject!(outputs[0].try_extract_tensor::(), SurrealErrorStatus::Unknown).view().clone().into_iter() { buffer.push(*i as f32); } } diff --git a/modules/core/src/execution/mod.rs b/modules/core/src/execution/mod.rs index 39ebd06..5b9d66c 100644 --- a/modules/core/src/execution/mod.rs +++ b/modules/core/src/execution/mod.rs @@ -1,3 +1,4 @@ //! Defines operations around performing computations on a loaded model. pub mod compute; -pub mod onnx_environment; +// pub mod onnx_environment; +pub mod session; diff --git a/modules/core/src/execution/session.rs b/modules/core/src/execution/session.rs new file mode 100644 index 0000000..69a847b --- /dev/null +++ b/modules/core/src/execution/session.rs @@ -0,0 +1,58 @@ +//! Defines the session module for the execution module. +use ort::session::Session; +use crate::errors::error::{SurrealError, SurrealErrorStatus}; +use crate::safe_eject; + +#[cfg(feature = "dynamic")] +use once_cell::sync::Lazy; +#[cfg(feature = "dynamic")] +use ort::environment::{EnvironmentBuilder, Environment}; +#[cfg(feature = "dynamic")] +use std::sync::{Arc, Mutex}; + +use std::sync::LazyLock; + + +/// Creates a session for a model. 
+/// +/// # Arguments +/// * `model_bytes` - The model bytes (usually extracted fromt the surml file) +/// +/// # Returns +/// A session object. +pub fn get_session(model_bytes: Vec) -> Result { + let builder = safe_eject!(Session::builder(), SurrealErrorStatus::Unknown); + + #[cfg(feature = "gpu")] + { + let cuda = CUDAExecutionProvider::default(); + if let Err(e) = cuda.register(&builder) { + eprintln!("Failed to register CUDA: {:?}. Falling back to CPU.", e); + } else { + println!("CUDA registered successfully"); + } + } + let session: Session = safe_eject!(builder + .commit_from_memory(&model_bytes), SurrealErrorStatus::Unknown); + Ok(session) +} + + +// #[cfg(feature = "dynamic")] +// pub static ORT_ENV: LazyLock>>>> = LazyLock::new(|| Arc::new(Mutex::new(None))); + + +#[cfg(feature = "dynamic")] +pub fn set_environment(dylib_path: String) -> Result<(), SurrealError> { + + let outcome: EnvironmentBuilder = ort::init_from(dylib_path); + match outcome.commit() { + Ok(env) => { + // ORT_ENV.lock().unwrap().replace(env); + }, + Err(e) => { + return Err(SurrealError::new(e.to_string(), SurrealErrorStatus::Unknown)); + } + } + Ok(()) +} diff --git a/modules/core/src/lib.rs b/modules/core/src/lib.rs index 625d826..5ec0080 100644 --- a/modules/core/src/lib.rs +++ b/modules/core/src/lib.rs @@ -60,7 +60,7 @@ //! //! ### Executing models //! We you load a `surml` file, you can execute the model with the following code: -//! ```rust +//! ```no_run //! use surrealml_core::storage::surml_file::SurMlFile; //! use surrealml_core::execution::compute::ModelComputation; //! use ndarray::ArrayD; @@ -94,5 +94,5 @@ pub mod errors; /// Returns the version of the ONNX runtime that is used. 
pub fn onnx_runtime() -> &'static str { - "1.16.0" + "1.20.0" } diff --git a/requirements.txt b/requirements.txt index 0700417..235cc8a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -onnxruntime==1.17.3 -numpy==1.26.3 +onnxruntime==1.20.0 +numpy==2.2.1 diff --git a/scripts/run_ci_workflows/c_wrapper_unit.sh b/scripts/run_ci_workflows/c_wrapper_unit.sh new file mode 100644 index 0000000..31f1e6c --- /dev/null +++ b/scripts/run_ci_workflows/c_wrapper_unit.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd ../.. + + +act -W .github/workflows/c_wrapper_unit_tests.yml pull_request \ No newline at end of file diff --git a/scripts/run_ci_workflows/surrealml_core_onnx.sh b/scripts/run_ci_workflows/surrealml_core_onnx.sh new file mode 100644 index 0000000..03125b8 --- /dev/null +++ b/scripts/run_ci_workflows/surrealml_core_onnx.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd ../.. + + +act -W .github/workflows/surrealml_core_onnx_test.yml pull_request \ No newline at end of file diff --git a/scripts/run_ci_workflows/surrealml_core_sklearn.sh b/scripts/run_ci_workflows/surrealml_core_sklearn.sh new file mode 100644 index 0000000..ac16761 --- /dev/null +++ b/scripts/run_ci_workflows/surrealml_core_sklearn.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd ../.. 
+ + +act -W .github/workflows/surrealml_core_test.yml pull_request diff --git a/scripts/run_ci_workflows/surrealml_core_tensorflow.sh b/scripts/run_ci_workflows/surrealml_core_tensorflow.sh new file mode 100644 index 0000000..d356b5b --- /dev/null +++ b/scripts/run_ci_workflows/surrealml_core_tensorflow.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd ../.. + + +act -W .github/workflows/surrealml_core_tensorflow_test.yml pull_request \ No newline at end of file diff --git a/scripts/run_ci_workflows/surrealml_core_torch.sh b/scripts/run_ci_workflows/surrealml_core_torch.sh new file mode 100644 index 0000000..bb5a1ff --- /dev/null +++ b/scripts/run_ci_workflows/surrealml_core_torch.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# navigate to directory +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +cd $SCRIPTPATH + +cd ../.. + + +act -W .github/workflows/surrealml_core_torch_test.yml pull_request diff --git a/tests/model_builder/tensorflow_assets.py b/tests/model_builder/tensorflow_assets.py index 3b30e7b..2b81793 100644 --- a/tests/model_builder/tensorflow_assets.py +++ b/tests/model_builder/tensorflow_assets.py @@ -1,4 +1,4 @@ -from tests.model_builder.utils import install_package +from model_builder.utils import install_package install_package("tf2onnx==1.16.1") install_package("tensorflow==2.16.1") import os @@ -7,7 +7,7 @@ from surrealml.model_templates.tensorflow.tensorflow_linear import export_model_onnx as linear_tensorflow_export_model_onnx from surrealml.model_templates.tensorflow.tensorflow_linear import export_model_surml as linear_tensorflow_export_model_surml -from tests.model_builder.utils import delete_directory, create_directory, MODEL_STASH_DIRECTORY +from model_builder.utils import delete_directory, create_directory, MODEL_STASH_DIRECTORY # create the model stash directory if it does not exist