diff --git a/.gitignore b/.gitignore index 6117870..bc1a8fe 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,5 @@ surrealml/rust_surrealml.cpython-310-darwin.so ./modules/pipelines/runners/batch_training_runner/run_env/ ./modules/pipelines/data_access/target/ ./modules/pipelines/runners/integrated_training_runner/run_env/ -modules/pipelines/runners/integrated_training_runner/run_env/ \ No newline at end of file +modules/pipelines/runners/integrated_training_runner/run_env/ +modules/pipelines/data_access/target/ diff --git a/README.md b/README.md index 879d58a..1300ba4 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,12 @@ For `PyTorch`: pip install "git+https://github.com/surrealdb/surrealml#egg=surrealml[torch]" ``` +For `Tensorflow`: + +```bash +pip install "git+https://github.com/surrealdb/surrealml#egg=surrealml[tensorflow]" +``` + After that, you can train your model and save it in the SurrealML format. ## Compilation config diff --git a/modules/core/Cargo.toml b/modules/core/Cargo.toml index f781b65..210ed78 100644 --- a/modules/core/Cargo.toml +++ b/modules/core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "surrealml-core" -version = "0.1.1" +version = "0.1.2" edition = "2021" build = "./build.rs" description = "The core machine learning library for SurrealML that enables SurrealDB to store and load ML models" diff --git a/modules/core/model_stash/sklearn/surml/linear.surml b/modules/core/model_stash/sklearn/surml/linear.surml index 03e8041..f092b50 100644 Binary files a/modules/core/model_stash/sklearn/surml/linear.surml and b/modules/core/model_stash/sklearn/surml/linear.surml differ diff --git a/modules/core/model_stash/tensorflow/surml/linear.surml b/modules/core/model_stash/tensorflow/surml/linear.surml new file mode 100644 index 0000000..1670153 Binary files /dev/null and b/modules/core/model_stash/tensorflow/surml/linear.surml differ diff --git a/modules/core/model_stash/torch/surml/linear.surml b/modules/core/model_stash/torch/surml/linear.surml index 8094df3..ec68ffb 100644 Binary files a/modules/core/model_stash/torch/surml/linear.surml and b/modules/core/model_stash/torch/surml/linear.surml differ diff --git a/modules/core/src/execution/compute.rs b/modules/core/src/execution/compute.rs index e87447d..f7fb955 100644 --- a/modules/core/src/execution/compute.rs +++ b/modules/core/src/execution/compute.rs @@ -225,4 +225,36 @@ mod tests { let output = model_computation.buffered_compute(&mut input_values).unwrap(); assert_eq!(output.len(), 1); } + + #[test] + fn test_raw_compute_linear_tensorflow() { + let mut file = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let raw_input = model_computation.input_tensor_from_key_bindings(input_values).unwrap(); + + let output = model_computation.raw_compute(raw_input, None).unwrap(); + assert_eq!(output.len(), 1); + } + + #[test] + fn test_buffered_compute_linear_tensorflow() { + let mut file = SurMlFile::from_file("./model_stash/tensorflow/surml/linear.surml").unwrap(); + let model_computation = ModelComputation { + surml_file: &mut file, + }; + + let mut input_values = HashMap::new(); + input_values.insert(String::from("squarefoot"), 1000.0); + input_values.insert(String::from("num_floors"), 2.0); + + let output = model_computation.buffered_compute(&mut input_values).unwrap(); + assert_eq!(output.len(), 1); + } } diff --git a/requirements.txt b/requirements.txt index 924b1f0..ad5846f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ -numpy==1.26.3 -skl2onnx==1.16.0 -scikit-learn==1.4.0 -torch==2.1.2 -onnx==1.15.0 -onnxruntime==1.16.3 \ No newline at end of file +numpy +skl2onnx +scikit-learn +torch +tf2onnx +tensorflow +onnxruntime diff --git a/setup.py b/setup.py index 6123ebc..39a44e8 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ "surrealml.model_templates.datasets", "surrealml.model_templates.sklearn", "surrealml.model_templates.torch", + "surrealml.model_templates.tensorflow", ], package_data={ "surrealml": ["binaries/*"], @@ -40,6 +41,10 @@ ], "torch": [ "torch==2.1.2" + ], + "tensorflow": [ + "tf2onnx==1.16.1", + "tensorflow==2.16.1" ] } ) diff --git a/surrealml/engine/__init__.py b/surrealml/engine/__init__.py index 7dfa704..9b3f16c 100644 --- a/surrealml/engine/__init__.py +++ b/surrealml/engine/__init__.py @@ -2,6 +2,7 @@ from surrealml.engine.sklearn import SklearnOnnxAdapter from surrealml.engine.torch import TorchOnnxAdapter +from surrealml.engine.tensorflow import TensorflowOnnxAdapter class Engine(Enum): @@ -12,7 +13,9 @@ class Engine(Enum): PYTORCH: The PyTorch engine which will be PyTorch and ONNX. NATIVE: The native engine which will be native rust and linfa. SKLEARN: The sklearn engine which will be sklearn and ONNX + TENSOFRLOW: The TensorFlow engine which will be TensorFlow and ONNX """ PYTORCH = "pytorch" NATIVE = "native" SKLEARN = "sklearn" + TENSORFLOW = "tensorflow" diff --git a/surrealml/engine/sklearn.py b/surrealml/engine/sklearn.py index ed5671b..3f9de69 100644 --- a/surrealml/engine/sklearn.py +++ b/surrealml/engine/sklearn.py @@ -28,11 +28,9 @@ def save_model_to_onnx(model, inputs) -> str: """ SklearnOnnxAdapter.check_dependency() file_path = create_file_cache_path() - # the below check is to satisfy type checkers - if skl2onnx is not None: - onnx = skl2onnx.to_onnx(model, inputs) + onnx = skl2onnx.to_onnx(model, inputs) - with open(file_path, "wb") as f: - f.write(onnx.SerializeToString()) + with open(file_path, "wb") as f: + f.write(onnx.SerializeToString()) - return file_path + return file_path diff --git a/surrealml/engine/tensorflow.py b/surrealml/engine/tensorflow.py new file mode 100644 index 0000000..05a2937 --- /dev/null +++ b/surrealml/engine/tensorflow.py @@ -0,0 +1,45 @@ +import os +import shutil +try: + import tf2onnx + import tensorflow as tf +except ImportError: + tf2onnx = None + tf = None + +from surrealml.engine.utils import TensorflowCache + + +class TensorflowOnnxAdapter: + + @staticmethod + def check_dependency() -> None: + """ + Checks if the tensorflow dependency is installed raising an error if not. + Please call this function when performing any tensorflow related operations. + """ + if tf2onnx is None or tf is None: + raise ImportError("tensorflow feature needs to be installed to use tensorflow features") + + @staticmethod + def save_model_to_onnx(model, inputs) -> str: + """ + Saves a tensorflow model to an onnx file. + + :param model: the tensorflow model to convert. + :param inputs: the inputs to the model needed to trace the model + :return: the path to the cache created with a unique id to prevent collisions. + """ + TensorflowOnnxAdapter.check_dependency() + cache = TensorflowCache() + + model_file_path = cache.new_cache_path + onnx_file_path = cache.new_cache_path + + tf.saved_model.save(model, model_file_path) + + os.system( + f"python -m tf2onnx.convert --saved-model {model_file_path} --output {onnx_file_path}" + ) + shutil.rmtree(model_file_path) + return onnx_file_path diff --git a/surrealml/engine/utils.py b/surrealml/engine/utils.py index ff69a36..dc413aa 100644 --- a/surrealml/engine/utils.py +++ b/surrealml/engine/utils.py @@ -1,17 +1,35 @@ +""" +This file contains utility functions for the engine. +""" import os import uuid -def create_file_cache_path(): +def create_file_cache_path(cache_folder: str = ".surmlcache") -> os.path: """ Creates a file cache path for the model (creating the file cache if not there). :return: the path to the cache created with a unique id to prevent collisions. """ - cache_folder = '.surmlcache' - if not os.path.exists(cache_folder): os.makedirs(cache_folder) unique_id = str(uuid.uuid4()) file_name = f"{unique_id}.surml" return os.path.join(cache_folder, file_name) + + +class TensorflowCache: + """ + A class to create a cache for tensorflow models. + + Attributes: + cache_path: The path to the cache created with a unique id to prevent collisions. + """ + def __init__(self) -> None: + create_file_cache_path() + self.cache_path = os.path.join(".surmlcache", "tensorflow") + create_file_cache_path(cache_folder=self.cache_path) + + @property + def new_cache_path(self) -> str: + return str(os.path.join(self.cache_path, str(uuid.uuid4()))) diff --git a/surrealml/model_templates/tensorflow/__init__.py b/surrealml/model_templates/tensorflow/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/surrealml/model_templates/tensorflow/tensorflow_linear.py b/surrealml/model_templates/tensorflow/tensorflow_linear.py new file mode 100644 index 0000000..d2ab2ce --- /dev/null +++ b/surrealml/model_templates/tensorflow/tensorflow_linear.py @@ -0,0 +1,96 @@ +""" +Trains a linear regression model in TensorFlow. Should be used for testing certain processes +for linear regression and TensorFlow. +""" +import os +import shutil + +import tensorflow as tf + +from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR + + +class LinearModel(tf.Module): + def __init__(self, W, b): + super(LinearModel, self).__init__() + self.W = tf.Variable(W, dtype=tf.float32) + self.b = tf.Variable(b, dtype=tf.float32) + + @tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)]) + def predict(self, x): + return tf.matmul(x, self.W) + self.b + + +def train_model(): + # Convert inputs and outputs to TensorFlow tensors + inputs = tf.constant(HOUSE_LINEAR["inputs"], dtype=tf.float32) + outputs = tf.constant(HOUSE_LINEAR["outputs"], dtype=tf.float32) + + # Model parameters + W = tf.Variable(tf.random.normal([2, 1]), name='weights') # Adjusted for two input features + b = tf.Variable(tf.zeros([1]), name='bias') + + # Training parameters + learning_rate = 0.01 + epochs = 100 + + # Training loop + for epoch in range(epochs): + with tf.GradientTape() as tape: + y_pred = tf.matmul(inputs, W) + b # Adjusted for matrix multiplication + loss = tf.reduce_mean(tf.square(y_pred - outputs)) + + gradients = tape.gradient(loss, [W, b]) + W.assign_sub(learning_rate * gradients[0]) + b.assign_sub(learning_rate * gradients[1]) + + if epoch % 10 == 0: # Print loss every 10 epochs + print(f"Epoch {epoch}: Loss = {loss.numpy()}") + + # Final parameters after training + final_W = W.numpy() + final_b = b.numpy() + + print(f"Trained W: {final_W}, Trained b: {final_b}") + return LinearModel(final_W, final_b) + + +def export_model_tf(model): + """ + Exports the model to TensorFlow SavedModel format. + """ + tf.saved_model.save(model, "linear_regression_model_tf") + return 'linear_regression_model_tf' + + +def export_model_onnx(model): + """ + Exports the model to ONNX format. + + :return: the path to the exported model. + """ + export_model_tf(model) + os.system("python -m tf2onnx.convert --saved-model linear_regression_model_tf --output model.onnx") + + with open("model.onnx", "rb") as f: + onnx_model = f.read() + shutil.rmtree("linear_regression_model_tf") + os.remove("model.onnx") + return onnx_model + + +def export_model_surml(model): + """ + Exports the model to SURML format. + + :param model: the model to export. + :return: the path to the exported model. + """ + from surrealml import SurMlFile, Engine + file = SurMlFile(model=model, name="linear", inputs=HOUSE_LINEAR["inputs"], engine=Engine.TENSORFLOW) + file.add_column("squarefoot") + file.add_column("num_floors") + file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) + file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) + file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) + return file diff --git a/surrealml/surml_file.py b/surrealml/surml_file.py index 552701c..d8e5a5b 100644 --- a/surrealml/surml_file.py +++ b/surrealml/surml_file.py @@ -3,7 +3,7 @@ """ from typing import Optional -from surrealml.engine import Engine, SklearnOnnxAdapter, TorchOnnxAdapter +from surrealml.engine import Engine, SklearnOnnxAdapter, TorchOnnxAdapter, TensorflowOnnxAdapter from surrealml.rust_adapter import RustAdapter @@ -49,6 +49,11 @@ def _cache_model(self) -> Optional[str]: model=self.model, inputs=self.inputs ) + elif self.engine == Engine.TENSORFLOW: + raw_file_path: str = TensorflowOnnxAdapter.save_model_to_onnx( + model=self.model, + inputs=self.inputs + ) else: raise ValueError(f"Engine {self.engine} not supported") return RustAdapter.pass_raw_model_into_rust(raw_file_path) diff --git a/tests/scripts/build_assets.py b/tests/scripts/build_assets.py index 1c525e4..15199f3 100644 --- a/tests/scripts/build_assets.py +++ b/tests/scripts/build_assets.py @@ -25,6 +25,10 @@ from surrealml.model_templates.torch.torch_linear import export_model_onnx as linear_torch_export_model_onnx from surrealml.model_templates.torch.torch_linear import export_model_surml as linear_torch_export_model_surml +from surrealml.model_templates.tensorflow.tensorflow_linear import train_model as linear_tensorflow_train_model +from surrealml.model_templates.tensorflow.tensorflow_linear import export_model_onnx as linear_tensorflow_export_model_onnx +from surrealml.model_templates.tensorflow.tensorflow_linear import export_model_surml as linear_tensorflow_export_model_surml + def delete_directory(dir_path: os.path) -> None: """ @@ -66,13 +70,19 @@ def write_file(file_path: os.path, model, file_name) -> None: core_directory = os.path.join(main_directory, "modules", "core") model_stash_directory = os.path.join(core_directory, "model_stash") + sklearn_stash_directory = os.path.join(model_stash_directory, "sklearn") sklearn_surml_stash_directory = os.path.join(sklearn_stash_directory, "surml") sklearn_onnx_stash_directory = os.path.join(sklearn_stash_directory, "onnx") + torch_stash_directory = os.path.join(model_stash_directory, "torch") torch_surml_stash_directory = os.path.join(torch_stash_directory, "surml") torch_onnx_stash_directory = os.path.join(torch_stash_directory, "onnx") +tensorflow_stash_directory = os.path.join(model_stash_directory, "tensorflow") +tensorflow_surml_stash_directory = os.path.join(tensorflow_stash_directory, "surml") +tensorflow_onnx_stash_directory = os.path.join(tensorflow_stash_directory, "onnx") + target_directory = os.path.join(main_directory, "target") egg_info_dir = os.path.join(main_directory, "surrealml.egg-info") @@ -84,13 +94,19 @@ def main(): delete_directory(model_stash_directory) os.mkdir(model_stash_directory) + os.mkdir(sklearn_stash_directory) os.mkdir(sklearn_surml_stash_directory) os.mkdir(sklearn_onnx_stash_directory) + os.mkdir(torch_stash_directory) os.mkdir(torch_surml_stash_directory) os.mkdir(torch_onnx_stash_directory) + os.mkdir(tensorflow_stash_directory) + os.mkdir(tensorflow_surml_stash_directory) + os.mkdir(tensorflow_onnx_stash_directory) + # train and stash sklearn models sklearn_linear_model = linear_sklearn_train_model() sklearn_linear_surml_file = linear_sklearn_export_model_surml(sklearn_linear_model) @@ -117,6 +133,15 @@ def main(): # os.path.join(torch_onnx_stash_directory, "linear.onnx") # ) + # train and stash tensorflow models + tensorflow_linear_model = linear_tensorflow_train_model() + tensorflow_linear_surml_file = linear_tensorflow_export_model_surml(tensorflow_linear_model) + tensorflow_linear_onnx_file = linear_tensorflow_export_model_onnx(tensorflow_linear_model) + + tensorflow_linear_surml_file.save( + path=str(os.path.join(tensorflow_surml_stash_directory, "linear.surml")) + ) + os.system(f"cd {model_stash_directory} && tree") shutil.rmtree(".surmlcache")