-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* adding glue module to enable error handling (#27) * added batching runner (#28) * Kings college london integration (#30) * adding build using binary downloads (#8) * adding build using binary downloads * sorting out the build.rs * updating build.rs for surrealml package * prepping version for release * now has target tracking (#10) * adding check in build.rs for docs.rs * removing build.rs for main surrealml to ensure that libraries using the core do not need to do anything in their build.rs * adding machine learning pipelines for bioengineering projects at Kings College London * Remove integrated_training_runner/run_env/ from tracking * adding machine learning pipelines for bioengineering projects at Kings College London * Update FFmpeg data access module and README (#29) * adding run_env to the gitignore --------- Co-authored-by: Yang Li <oliverlee2018@163.com> * bumping the version * updating the README and module * updating the surrealml-core deployment workflow * updating the surrealml-core deployment workflow * updating cargo * Error modules (#36) * Develop (#35) * adding glue module to enable error handling (#27) * added batching runner (#28) * Kings college london integration (#30) * adding build using binary downloads (#8) * adding build using binary downloads * sorting out the build.rs * updating build.rs for surrealml package * prepping version for release * now has target tracking (#10) * adding check in build.rs for docs.rs * removing build.rs for main surrealml to ensure that libraries using the core do not need to do anything in their build.rs * adding machine learning pipelines for bioengineering projects at Kings College London * Remove integrated_training_runner/run_env/ from tracking * adding machine learning pipelines for bioengineering projects at Kings College London * Update FFmpeg data access module and README (#29) * adding run_env to the gitignore --------- Co-authored-by: Yang Li <oliverlee2018@163.com> * bumping the version * updating the README and module * updating the surrealml-core deployment workflow * updating the surrealml-core deployment workflow * updating cargo --------- Co-authored-by: Sam Hillman <116303632+SHillman836@users.noreply.github.com> Co-authored-by: Yang Li <oliverlee2018@163.com> * merging error modules into the core * merging error modules into the core * merging error modules into the core --------- Co-authored-by: Sam Hillman <116303632+SHillman836@users.noreply.github.com> Co-authored-by: Yang Li <oliverlee2018@163.com> * Index overflow (#40) * adding buffer out of index check * adding buffer out of index check * updating testing around meta data (#42) * updating the naming and increasing tests around the meta data of the stored ML models * updating the naming and increasing tests around the meta data of the stored ML models * updating the naming and increasing tests around the meta data of the stored ML models * Tensorflow support (#44) * stashing for branch switch * adding tests for tensorflow * adding tests for tensorflow * fixing requirement conflicts * fixing requirement conflicts * fixing requirement conflicts * fixing requirement conflicts * fixing requirement conflicts --------- Co-authored-by: Sam Hillman <116303632+SHillman836@users.noreply.github.com> Co-authored-by: Yang Li <oliverlee2018@163.com>
- Loading branch information
1 parent
7d129de
commit 9f2a7bf
Showing
17 changed files
with
253 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
numpy==1.26.3 | ||
skl2onnx==1.16.0 | ||
scikit-learn==1.4.0 | ||
torch==2.1.2 | ||
onnx==1.15.0 | ||
onnxruntime==1.16.3 | ||
numpy | ||
skl2onnx | ||
scikit-learn | ||
torch | ||
tf2onnx | ||
tensorflow | ||
onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import shutil | ||
try: | ||
import tf2onnx | ||
import tensorflow as tf | ||
except ImportError: | ||
tf2onnx = None | ||
tf = None | ||
|
||
from surrealml.engine.utils import TensorflowCache | ||
|
||
|
||
class TensorflowOnnxAdapter: | ||
|
||
@staticmethod | ||
def check_dependency() -> None: | ||
""" | ||
Checks if the tensorflow dependency is installed raising an error if not. | ||
Please call this function when performing any tensorflow related operations. | ||
""" | ||
if tf2onnx is None or tf is None: | ||
raise ImportError("tensorflow feature needs to be installed to use tensorflow features") | ||
|
||
@staticmethod | ||
def save_model_to_onnx(model, inputs) -> str: | ||
""" | ||
Saves a tensorflow model to an onnx file. | ||
:param model: the tensorflow model to convert. | ||
:param inputs: the inputs to the model needed to trace the model | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
TensorflowOnnxAdapter.check_dependency() | ||
cache = TensorflowCache() | ||
|
||
model_file_path = cache.new_cache_path | ||
onnx_file_path = cache.new_cache_path | ||
|
||
tf.saved_model.save(model, model_file_path) | ||
|
||
os.system( | ||
f"python -m tf2onnx.convert --saved-model {model_file_path} --output {onnx_file_path}" | ||
) | ||
shutil.rmtree(model_file_path) | ||
return onnx_file_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,35 @@ | ||
""" | ||
This file contains utility functions for the engine. | ||
""" | ||
import os | ||
import uuid | ||
|
||
|
||
def create_file_cache_path(): | ||
def create_file_cache_path(cache_folder: str = ".surmlcache") -> os.path: | ||
""" | ||
Creates a file cache path for the model (creating the file cache if not there). | ||
:return: the path to the cache created with a unique id to prevent collisions. | ||
""" | ||
cache_folder = '.surmlcache' | ||
|
||
if not os.path.exists(cache_folder): | ||
os.makedirs(cache_folder) | ||
unique_id = str(uuid.uuid4()) | ||
file_name = f"{unique_id}.surml" | ||
return os.path.join(cache_folder, file_name) | ||
|
||
|
||
class TensorflowCache: | ||
""" | ||
A class to create a cache for tensorflow models. | ||
Attributes: | ||
cache_path: The path to the cache created with a unique id to prevent collisions. | ||
""" | ||
def __init__(self) -> None: | ||
create_file_cache_path() | ||
self.cache_path = os.path.join(".surmlcache", "tensorflow") | ||
create_file_cache_path(cache_folder=self.cache_path) | ||
|
||
@property | ||
def new_cache_path(self) -> str: | ||
return str(os.path.join(self.cache_path, str(uuid.uuid4()))) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Trains a linear regression model in TensorFlow. Should be used for testing certain processes | ||
for linear regression and TensorFlow. | ||
""" | ||
import os | ||
import shutil | ||
|
||
import tensorflow as tf | ||
|
||
from surrealml.model_templates.datasets.house_linear import HOUSE_LINEAR | ||
|
||
|
||
class LinearModel(tf.Module): | ||
def __init__(self, W, b): | ||
super(LinearModel, self).__init__() | ||
self.W = tf.Variable(W, dtype=tf.float32) | ||
self.b = tf.Variable(b, dtype=tf.float32) | ||
|
||
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)]) | ||
def predict(self, x): | ||
return tf.matmul(x, self.W) + self.b | ||
|
||
|
||
def train_model(): | ||
# Convert inputs and outputs to TensorFlow tensors | ||
inputs = tf.constant(HOUSE_LINEAR["inputs"], dtype=tf.float32) | ||
outputs = tf.constant(HOUSE_LINEAR["outputs"], dtype=tf.float32) | ||
|
||
# Model parameters | ||
W = tf.Variable(tf.random.normal([2, 1]), name='weights') # Adjusted for two input features | ||
b = tf.Variable(tf.zeros([1]), name='bias') | ||
|
||
# Training parameters | ||
learning_rate = 0.01 | ||
epochs = 100 | ||
|
||
# Training loop | ||
for epoch in range(epochs): | ||
with tf.GradientTape() as tape: | ||
y_pred = tf.matmul(inputs, W) + b # Adjusted for matrix multiplication | ||
loss = tf.reduce_mean(tf.square(y_pred - outputs)) | ||
|
||
gradients = tape.gradient(loss, [W, b]) | ||
W.assign_sub(learning_rate * gradients[0]) | ||
b.assign_sub(learning_rate * gradients[1]) | ||
|
||
if epoch % 10 == 0: # Print loss every 10 epochs | ||
print(f"Epoch {epoch}: Loss = {loss.numpy()}") | ||
|
||
# Final parameters after training | ||
final_W = W.numpy() | ||
final_b = b.numpy() | ||
|
||
print(f"Trained W: {final_W}, Trained b: {final_b}") | ||
return LinearModel(final_W, final_b) | ||
|
||
|
||
def export_model_tf(model): | ||
""" | ||
Exports the model to TensorFlow SavedModel format. | ||
""" | ||
tf.saved_model.save(model, "linear_regression_model_tf") | ||
return 'linear_regression_model_tf' | ||
|
||
|
||
def export_model_onnx(model): | ||
""" | ||
Exports the model to ONNX format. | ||
:return: the path to the exported model. | ||
""" | ||
export_model_tf(model) | ||
os.system("python -m tf2onnx.convert --saved-model linear_regression_model_tf --output model.onnx") | ||
|
||
with open("model.onnx", "rb") as f: | ||
onnx_model = f.read() | ||
shutil.rmtree("linear_regression_model_tf") | ||
os.remove("model.onnx") | ||
return onnx_model | ||
|
||
|
||
def export_model_surml(model): | ||
""" | ||
Exports the model to SURML format. | ||
:param model: the model to export. | ||
:return: the path to the exported model. | ||
""" | ||
from surrealml import SurMlFile, Engine | ||
file = SurMlFile(model=model, name="linear", inputs=HOUSE_LINEAR["inputs"], engine=Engine.TENSORFLOW) | ||
file.add_column("squarefoot") | ||
file.add_column("num_floors") | ||
file.add_normaliser("squarefoot", "z_score", HOUSE_LINEAR["squarefoot"].mean(), HOUSE_LINEAR["squarefoot"].std()) | ||
file.add_normaliser("num_floors", "z_score", HOUSE_LINEAR["num_floors"].mean(), HOUSE_LINEAR["num_floors"].std()) | ||
file.add_output("house_price", "z_score", HOUSE_LINEAR["outputs"].mean(), HOUSE_LINEAR["outputs"].std()) | ||
return file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.