diff --git a/.gitignore b/.gitignore
index efa62b0c6..c9fe0af83 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,11 @@
*.pyo
*.swp
*~
+*.log
+*.db
+*.pt
+*.rst
+*.tar.gz
.DS_Store
.cache
.classpath
@@ -22,3 +27,21 @@
.settings
hs_err*.log
target
+/examples/ML+DL-Examples/Spark-DL/dl_inference/pytriton
+/examples/ML+DL-Examples/Spark-DL/dl_inference/archive
+# /examples/ML+DL-Examples/Spark-Rapids-ML/pca
+/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/new_ideas
+/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/torch_notebooks
+/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/tf_notebooks
+/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/datasets
+/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/spark-dl-datasets
+/examples/ML+DL-Examples/Spark-DL/dl_inference/pytorch/models
+**/spark-dl-datasets
+**/models
+**/predictions
+**/dl_inference/databricks/*.ipynb
+**/dl_inference/dataproc/*.ipynb
+**/dl_inference/huggingface/.ipynb_checkpoints
+/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/spark-dl-notebooks
+/examples/ML+DL-Examples/Spark-DL/dl_inference/test.ipynb
+/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/copy_notebooks_gcp.sh
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md b/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md
index d704f2995..034ee2af5 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/README.md
@@ -1,11 +1,17 @@
-# Spark DL Inference Using External Frameworks
+# Deep Learning Inference on Spark
-Example notebooks for the [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf) function introduced in Spark 3.4.
+Example notebooks demonstrating **distributed deep learning inference** using the [predict_batch_udf](https://developer.nvidia.com/blog/distributed-deep-learning-made-easy-with-spark-3-4/) introduced in Spark 3.4.0.
+These notebooks also demonstrate integration with [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html), an open-source, GPU-accelerated serving solution for DL.
-## Overview
+## Contents
+- [Overview](#overview)
+- [Running Locally](#running-locally)
+- [Running on Cloud](#running-on-cloud-environments)
+- [Integration with Triton Inference Server](#inference-with-triton)
-This directory contains notebooks for each DL framework (based on their own published examples). The goal is to demonstrate how models trained and saved on single-node machines can be easily used for parallel inferencing on Spark clusters.
+## Overview
+These notebooks demonstrate how models from external frameworks (Torch, Huggingface, Tensorflow) trained on single-node machines can be used for large-scale distributed inference on Spark clusters.
For example, a basic model trained in TensorFlow and saved on disk as "mnist_model" can be used in Spark as follows:
```
import numpy as np
@@ -28,35 +34,33 @@ df = spark.read.parquet("mnist_data")
predictions = df.withColumn("preds", mnist("data")).collect()
```
-In this simple case, the `predict_batch_fn` will use TensorFlow APIs to load the model and return a simple `predict` function which operates on numpy arrays. The `predict_batch_udf` will automatically convert the Spark DataFrame columns to the expected numpy inputs.
+In this simple case, the `predict_batch_fn` will use TensorFlow APIs to load the model and return a simple `predict` function. The `predict_batch_udf` will handle the data conversion from Spark DataFrame columns into batched numpy inputs.
+
-All notebooks have been saved with sample outputs for quick browsing.
-Here is a full list of the notebooks with their published example links:
+#### Notebook List
-| | Category | Notebook Name | Description | Link
+Below is a full list of the notebooks with links to the examples they are based on. All notebooks have been saved with sample outputs for quick browsing.
+
+| | Framework | Notebook Name | Description | Link
| ------------- | ------------- | ------------- | ------------- | -------------
| 1 | PyTorch | Image Classification | Training a model to predict clothing categories in FashionMNIST, including accelerated inference with Torch-TensorRT. | [Link](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html)
-| 2 | PyTorch | Regression | Training a model to predict housing prices in the California Housing Dataset, including accelerated inference with Torch-TensorRT. | [Link](https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md)
+| 2 | PyTorch | Housing Regression | Training a model to predict housing prices in the California Housing Dataset, including accelerated inference with Torch-TensorRT. | [Link](https://github.com/christianversloot/machine-learning-articles/blob/main/how-to-create-a-neural-network-for-regression-with-pytorch.md)
| 3 | Tensorflow | Image Classification | Training a model to predict hand-written digits in MNIST. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/save_and_load.ipynb)
-| 4 | Tensorflow | Feature Columns | Training a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/structured_data/preprocessing_layers.ipynb)
-| 5 | Tensorflow | Keras Metadata | Training ResNet-50 to perform flower recognition on Databricks. | [Link](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/keras-metadata.html)
+| 4 | Tensorflow | Keras Preprocessing | Training a model with preprocessing layers to predict likelihood of pet adoption in the PetFinder mini dataset. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/structured_data/preprocessing_layers.ipynb)
+| 5 | Tensorflow | Keras Resnet50 | Training ResNet-50 to perform flower recognition from flower images. | [Link](https://docs.databricks.com/en/_extras/notebooks/source/deep-learning/keras-metadata.html)
| 6 | Tensorflow | Text Classification | Training a model to perform sentiment analysis on the IMDB dataset. | [Link](https://github.com/tensorflow/docs/blob/master/site/en/tutorials/keras/text_classification.ipynb)
-| 7+8 | HuggingFace | Conditional Generation | Sentence translation using the T5 text-to-text transformer, with notebooks demoing both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/model_doc/t5#t5)
-| 9+10 | HuggingFace | Pipelines | Sentiment analysis using Huggingface pipelines, with notebooks demoing both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/quicktour#pipeline-usage)
-| 11 | HuggingFace | Sentence Transformers | Sentence embeddings using the SentenceTransformers framework in Torch. | [Link](https://huggingface.co/sentence-transformers)
+| 7+8 | HuggingFace | Conditional Generation | Sentence translation using the T5 text-to-text transformer for both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/model_doc/t5#t5)
+| 9+10 | HuggingFace | Pipelines | Sentiment analysis using Huggingface pipelines for both Torch and Tensorflow. | [Link](https://huggingface.co/docs/transformers/quicktour#pipeline-usage)
+| 11 | HuggingFace | Sentence Transformers | Sentence embeddings using SentenceTransformers in Torch. | [Link](https://huggingface.co/sentence-transformers)
-## Running the Notebooks
+## Running Locally
-If you want to run the notebooks yourself, please follow these instructions.
-
-**Notes**:
-- The notebooks require a GPU environment for the executors.
-- Please create separate environments for PyTorch and Tensorflow examples as specified below. This will avoid conflicts between the CUDA libraries bundled with their respective versions. The Huggingface examples will have a `_torch` or `_tf` suffix to specify the environment used.
-- The PyTorch notebooks include model compilation and accelerated inference with TensorRT. While not included in the notebooks, Tensorflow also supports [integration with TensorRT](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html), but may require downgrading the TF version.
-- For demonstration purposes, these examples just use a local Spark Standalone cluster with a single executor, but you should be able to run them on any distributed Spark cluster.
+To run the notebooks locally, please follow these instructions:
#### Create environment
+Each notebook has a suffix `_torch` or `_tf` specifying the environment used.
+
**For PyTorch:**
```
conda create -n spark-dl-torch python=3.11
@@ -70,36 +74,57 @@ conda activate spark-dl-tf
pip install -r tf_requirements.txt
```
-#### Launch Jupyter + Spark
+#### Start Cluster
+
+For demonstration, these instructions just use a local Standalone cluster with a single executor, but the notebooks can be run on any distributed Spark cluster. For cloud environments, see [below](#running-on-cloud-environments).
+```shell
+# Replace with your Spark installation path
+export SPARK_HOME=
```
-# setup environment variables
-export SPARK_HOME=/path/to/spark
+
+```shell
+# Configure and start cluster
export MASTER=spark://$(hostname):7077
export SPARK_WORKER_INSTANCES=1
export CORES_PER_WORKER=8
-export PYSPARK_DRIVER_PYTHON=jupyter
-export PYSPARK_DRIVER_PYTHON_OPTS='lab'
-
-# start spark standalone cluster
+export SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=1 -Dspark.worker.resource.gpu.discoveryScript=$SPARK_HOME/examples/src/main/scripts/getGpusResources.sh"
${SPARK_HOME}/sbin/start-master.sh; ${SPARK_HOME}/sbin/start-worker.sh -c ${CORES_PER_WORKER} -m 16G ${MASTER}
+```
-# start jupyter with pyspark
-${SPARK_HOME}/bin/pyspark --master ${MASTER} \
---driver-memory 8G \
---executor-memory 8G \
---conf spark.python.worker.reuse=True
+The notebooks are ready to run! Each notebook has a cell to connect to the standalone cluster and create a SparkSession.
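+
+For reference, a minimal version of that connection cell might look like the following sketch (the app name is arbitrary, and the master URL assumes the standalone cluster started above):
+```python
+import os
+from pyspark.sql import SparkSession
+
+# Connect to the local standalone cluster; MASTER was exported above as spark://$(hostname):7077
+spark = (
+    SparkSession.builder
+    .master(os.environ.get("MASTER", "spark://localhost:7077"))
+    .appName("spark-dl-inference")
+    .config("spark.python.worker.reuse", "true")
+    .getOrCreate()
+)
+```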
-# BROWSE to localhost:8888 to view/run notebooks
+**Notes**:
+- Please create separate environments for PyTorch and Tensorflow notebooks as specified above. This will avoid conflicts between the CUDA libraries bundled with their respective versions.
+- `requirements.txt` installs pyspark>=3.4.0. Make sure the installed PySpark version is compatible with your system's Spark installation.
+- The notebooks require a GPU environment for the executors.
+- The PyTorch notebooks include model compilation and accelerated inference with TensorRT. While not included in the notebooks, Tensorflow also supports [integration with TensorRT](https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html), but as of writing it is not supported in TF==2.17.0.
-# stop spark standalone cluster
-${SPARK_HOME}/sbin/stop-worker.sh; ${SPARK_HOME}/sbin/stop-master.sh
+**Troubleshooting:**
+If you encounter issues starting the Triton server, you may need to link your libstdc++ file to the conda environment, e.g.:
+```shell
+ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_PREFIX}/lib/libstdc++.so.6
```
-## Triton Inference Server
+## Running on Cloud Environments
+
+We also provide instructions to run the notebooks on cloud service provider (CSP) Spark environments.
+See the instructions for [AWS EMR](aws-emr/README.md), [Databricks](databricks/README.md), and [GCP Dataproc](dataproc/README.md).
+
+## Inference with Triton
+
+The notebooks also demonstrate integration with the [Triton Inference Server](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html), an open-source serving platform for deep learning models, which includes many [features and performance optimizations](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/index.html#triton-major-features) to streamline inference.
+The notebooks use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like Python framework that handles communication with the Triton server.
+
+
-The example notebooks also demonstrate integration with [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL.
+The diagram above shows how Spark distributes inference tasks to run on the Triton Inference Server, with PyTriton handling request/response communication with the server.
-**Note**: Some examples may require special configuration of server as highlighted in the notebooks.
+The process looks like this:
+- Distribute a PyTriton task across the Spark cluster, instructing each worker to launch a Triton server process.
+ - Use stage-level scheduling to ensure there is a 1:1 mapping between worker nodes and servers.
+- Define a Triton inference function, which contains a client that binds to the local server on a given worker and sends inference requests.
+- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.
+- Finally, distribute a shutdown signal to terminate the Triton server processes on each worker.
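+
+As a rough sketch of the client side (not the notebooks' exact code; the model name, port, return type, and DataFrame are hypothetical, and each executor is assumed to already be running a Triton server from the first step):
+```python
+import numpy as np
+from pyspark.ml.functions import predict_batch_udf
+from pyspark.sql.types import ArrayType, FloatType
+
+def triton_fn(url: str, model_name: str):
+    """Return a predict function that sends batches to the local Triton server."""
+    from pytriton.client import ModelClient
+    client = ModelClient(url, model_name)
+
+    def predict(inputs: np.ndarray) -> np.ndarray:
+        # infer_batch sends the batch to the server and returns a dict of output arrays
+        result = client.infer_batch(inputs)
+        return next(iter(result.values()))
+
+    return predict
+
+triton_udf = predict_batch_udf(
+    lambda: triton_fn("http://localhost:8000", "my_model"),
+    return_type=ArrayType(FloatType()),
+    batch_size=1024,
+)
+
+# assuming df is a Spark DataFrame with a "data" column:
+preds = df.withColumn("preds", triton_udf("data"))
+```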
-**Note**: for demonstration purposes, the Triton Inference Server integrations just launch the server in a docker container on the local host, so you will need to [install docker](https://docs.docker.com/engine/install/) on your local host. Most real-world deployments will likely be hosted on remote machines.
+For more information on how PyTriton works, see the [PyTriton docs](https://triton-inference-server.github.io/pytriton/latest/high_level_design/).
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/README.md b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/README.md
new file mode 100644
index 000000000..67fb6633f
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/README.md
@@ -0,0 +1,79 @@
+# Spark DL Inference on AWS EMR
+
+These instructions assume you already have an AWS account with access to EMR.
+**Note**: fields in \<brackets\> require user input.
+
+## Setup
+
+#### Setup AWS CLI
+
+1. Install the [AWS CLI](https://docs.aws.amazon.com/emr/latest/EMR-on-EKS-DevelopmentGuide/setting-up-cli.html).
+
+2. Initialize the CLI via `aws configure`. You may need to create access keys by following [Authenticating using IAM user credentials](https://docs.aws.amazon.com/cli/latest/userguide/cli-authentication-user.html). You can find your default region name (e.g. Ohio) on the right of the top navigation bar. Clicking the region name will show the region code (e.g. us-east-2 for Ohio).
+
+ ```shell
+ aws configure
+ AWS Access Key ID [None]:
+ AWS Secret Access Key [None]:
+ Default region name [None]:
+ Default output format [None]: json
+ ```
+
+#### Copy Files to S3
+
+3. Create an S3 bucket if you don't already have one.
+ ```shell
+ export S3_BUCKET=
+ aws s3 mb s3://${S3_BUCKET}
+ ```
+
+4. Upload the initialization script to S3, using the `_torch` or `_tf` variant matching the notebooks you plan to run. As an example for Torch:
+   ```shell
+   aws s3 cp setup/init_spark_dl_torch.sh s3://${S3_BUCKET}/init_spark_dl_torch.sh
+   ```
+- List the available subnets in the CLI, then pick a SubnetId (e.g. subnet-0744566f in AvailabilityZone us-east-2a).
+ ```shell
+ aws ec2 describe-subnets
+ export SUBNET_ID=
+ ```
+
+- Create a cluster with at least two single-GPU workers. The command will print a ClusterId to the terminal. Note that three GPU nodes are requested here because EMR picks one node (either CORE or TASK) to run the JupyterLab service for notebooks and will not use that node for compute.
+ ```shell
+   export CLUSTER_NAME="spark-dl-inference"
+ export CUR_DIR=$(pwd)
+
+ aws emr create-cluster \
+ --name ${CLUSTER_NAME} \
+ --release-label emr-7.3.0 \
+ --ebs-root-volume-size=32 \
+ --applications Name=Hadoop Name=Livy Name=Spark Name=JupyterEnterpriseGateway \
+ --service-role EMR_DefaultRole \
+ --log-uri s3://${S3_BUCKET}/logs \
+ --ec2-attributes SubnetId=${SUBNET_ID},InstanceProfile=EMR_EC2_DefaultRole \
+ --instance-groups InstanceGroupType=MASTER,InstanceCount=1,InstanceType=m4.2xlarge \
+ InstanceGroupType=CORE,InstanceCount=3,InstanceType=g4dn.2xlarge \
+   --configurations file://${CUR_DIR}/setup/init-config.json \
+   --bootstrap-actions Name='Spark DL Bootstrap action',Path=s3://${S3_BUCKET}/init_spark_dl_torch.sh
+ ```
+
+- In the [AWS EMR console](https://console.aws.amazon.com/emr/), click "Clusters" to find the cluster ID of the created cluster. Wait until all instances show the Status "Running".
+- In the [AWS EMR console](https://console.aws.amazon.com/emr/), click "Workspaces (Notebooks)" and create a workspace. Wait until the status becomes ready, at which point a JupyterLab page will open.
+
+- Enter the created workspace. Click the "Cluster" button (usually the second button from the top of the left navigation bar) and attach the workspace to the newly created cluster using its cluster ID.
+
+- Use the default notebook or create/upload a new notebook. Set the notebook kernel to "PySpark".
+
+- If your notebooks depend on additional Python files, add the following to a new cell at the beginning of the notebook, replacing "s3://path/to/dependencies.zip" with the actual S3 path to a zip of your dependencies:
+   ```python
+   %%configure -f
+   {
+       "conf": {
+           "spark.submit.pyFiles": "s3://path/to/dependencies.zip"
+       }
+   }
+   ```
+- Run the notebook cells.
+ **Note**: these settings are for demonstration purposes only. Additional tuning may be required for optimal performance.
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config-spark-rapids.json b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config-spark-rapids.json
new file mode 100644
index 000000000..67e10640f
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config-spark-rapids.json
@@ -0,0 +1,85 @@
+[
+ {
+ "Classification":"spark",
+ "Properties":{
+ "enableSparkRapids":"true"
+ }
+ },
+ {
+ "Classification":"yarn-site",
+ "Properties":{
+ "yarn.nodemanager.resource-plugins":"yarn.io/gpu",
+ "yarn.resource-types":"yarn.io/gpu",
+ "yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices":"auto",
+ "yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables":"/usr/bin",
+ "yarn.nodemanager.linux-container-executor.cgroups.mount":"true",
+ "yarn.nodemanager.linux-container-executor.cgroups.mount-path":"/spark-rapids-cgroup",
+ "yarn.nodemanager.linux-container-executor.cgroups.hierarchy":"yarn",
+ "yarn.nodemanager.container-executor.class":"org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor"
+ }
+ },
+ {
+ "Classification":"container-executor",
+ "Properties":{
+
+ },
+ "Configurations":[
+ {
+ "Classification":"gpu",
+ "Properties":{
+ "module.enabled":"true"
+ }
+ },
+ {
+ "Classification":"cgroups",
+ "Properties":{
+ "root":"/spark-rapids-cgroup",
+ "yarn-hierarchy":"yarn"
+ }
+ }
+ ]
+ },
+ {
+ "Classification":"spark-defaults",
+ "Properties":{
+ "spark.plugins":"com.nvidia.spark.SQLPlugin",
+ "spark.sql.sources.useV1SourceList":"",
+ "spark.executor.resource.gpu.discoveryScript":"/usr/lib/spark/scripts/gpu/getGpusResources.sh",
+ "spark.executor.extraLibraryPath":"/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native",
+ "spark.rapids.sql.concurrentGpuTasks":"2",
+ "spark.executor.resource.gpu.amount":"1",
+ "spark.executor.cores":"8",
+ "spark.task.resource.gpu.amount":"0.125",
+ "spark.executor.memoryOverhead":"2G",
+ "spark.sql.files.maxPartitionBytes":"256m",
+ "spark.sql.adaptive.enabled":"false",
+ "spark.python.worker.reuse":"true",
+ "spark.rapids.memory.gpu.minAllocFraction":"0.0001",
+ "spark.rapids.memory.gpu.pooling.enabled":"false",
+ "spark.rapids.sql.explain":"NONE",
+ "spark.rapids.memory.gpu.reserve":"20",
+ "spark.rapids.sql.python.gpu.enabled":"true",
+ "spark.rapids.memory.pinnedPool.size":"2G",
+ "spark.rapids.sql.batchSizeBytes":"512m",
+ "spark.locality.wait":"0",
+ "spark.sql.execution.sortBeforeRepartition":"false",
+ "spark.sql.execution.arrow.pyspark.enabled":"true",
+ "spark.sql.execution.arrow.maxRecordsPerBatch":"100000",
+ "spark.sql.cache.serializer":"com.nvidia.spark.ParquetCachedBatchSerializer",
+ "spark.pyspark.python":"/usr/local/bin/python3.10",
+ "spark.pyspark.driver.python":"/usr/local/bin/python3.10",
+ "spark.dynamicAllocation.enabled":"false",
+ "spark.driver.memory":"20g",
+ "spark.rpc.message.maxSize":"512",
+ "spark.executorEnv.CUPY_CACHE_DIR":"/tmp/.cupy",
+ "spark.executorEnv.NCCL_DEBUG":"INFO",
+ "spark.executorEnv.NCCL_SOCKET_IFNAME":"ens"
+ }
+ },
+ {
+ "Classification":"capacity-scheduler",
+ "Properties":{
+ "yarn.scheduler.capacity.resource-calculator":"org.apache.hadoop.yarn.util.resource.DominantResourceCalculator"
+ }
+ }
+]
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config.json b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config.json
new file mode 100644
index 000000000..bb780e62e
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init-config.json
@@ -0,0 +1,62 @@
+[
+ {
+ "Classification":"yarn-site",
+ "Properties":{
+ "yarn.nodemanager.resource-plugins":"yarn.io/gpu",
+ "yarn.resource-types":"yarn.io/gpu",
+ "yarn.nodemanager.resource-plugins.gpu.allowed-gpu-devices":"auto",
+ "yarn.nodemanager.resource-plugins.gpu.path-to-discovery-executables":"/usr/bin",
+ "yarn.nodemanager.linux-container-executor.cgroups.mount":"true",
+ "yarn.nodemanager.linux-container-executor.cgroups.mount-path":"/spark-rapids-cgroup",
+ "yarn.nodemanager.linux-container-executor.cgroups.hierarchy":"yarn",
+ "yarn.nodemanager.container-executor.class":"org.apache.hadoop.yarn.server.nodemanager.LinuxContainerExecutor"
+ }
+ },
+ {
+ "Classification":"container-executor",
+ "Properties":{
+
+ },
+ "Configurations":[
+ {
+ "Classification":"gpu",
+ "Properties":{
+ "module.enabled":"true"
+ }
+ },
+ {
+ "Classification":"cgroups",
+ "Properties":{
+ "root":"/spark-rapids-cgroup",
+ "yarn-hierarchy":"yarn"
+ }
+ }
+ ]
+ },
+ {
+ "Classification":"spark-defaults",
+ "Properties":{
+ "spark.sql.sources.useV1SourceList":"",
+ "spark.executor.resource.gpu.discoveryScript":"/usr/lib/spark/scripts/gpu/getGpusResources.sh",
+ "spark.executor.extraLibraryPath":"/usr/local/cuda/targets/x86_64-linux/lib:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/cuda/compat/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/lib/hadoop/lib/native:/usr/lib/hadoop-lzo/lib/native:/docker/usr/lib/hadoop/lib/native:/docker/usr/lib/hadoop-lzo/lib/native",
+ "spark.executor.resource.gpu.amount":"1",
+ "spark.executor.cores":"8",
+ "spark.task.resource.gpu.amount":"0.125",
+ "spark.python.worker.reuse":"true",
+ "spark.sql.execution.arrow.pyspark.enabled":"true",
+ "spark.pyspark.python":"/usr/local/bin/python3.10",
+ "spark.pyspark.driver.python":"/usr/local/bin/python3.10",
+ "spark.dynamicAllocation.enabled":"false",
+ "spark.driver.memory":"20g",
+ "spark.executorEnv.CUPY_CACHE_DIR":"/tmp/.cupy",
+ "spark.executorEnv.NCCL_DEBUG":"INFO",
+ "spark.executorEnv.NCCL_SOCKET_IFNAME":"ens"
+ }
+ },
+ {
+ "Classification":"capacity-scheduler",
+ "Properties":{
+ "yarn.scheduler.capacity.resource-calculator":"org.apache.hadoop.yarn.util.resource.DominantResourceCalculator"
+ }
+ }
+]
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_tf.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_tf.sh
new file mode 100644
index 000000000..0a7c62611
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_tf.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+set -ex
+
+sudo mkdir -p /spark-rapids-cgroup/devices
+sudo mount -t cgroup -o devices cgroupv1-devices /spark-rapids-cgroup/devices
+sudo chmod a+rwx -R /spark-rapids-cgroup
+
+sudo yum update -y
+sudo yum install -y gcc bzip2-devel libffi-devel tar gzip wget make
+sudo yum install -y mysql-devel --skip-broken
+sudo bash -c "wget https://www.python.org/ftp/python/3.10.9/Python-3.10.9.tgz && \
+tar xzf Python-3.10.9.tgz && cd Python-3.10.9 && \
+./configure --enable-optimizations && make altinstall"
+
+RAPIDS_VERSION=24.12.0
+
+sudo /usr/local/bin/pip3.10 install --upgrade pip
+
+# install cudf
+sudo /usr/local/bin/pip3.10 install --no-cache-dir cudf-cu12 --extra-index-url=https://pypi.nvidia.com --verbose
+
+cat <<EOF > temp_requirements.txt
+datasets==3.*
+transformers
+urllib3<2
+nvidia-pytriton
+EOF
+
+sudo /usr/local/bin/pip3.10 install --upgrade --force-reinstall -r temp_requirements.txt
+rm temp_requirements.txt
+
+sudo /usr/local/bin/pip3.10 list
+
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_torch.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_torch.sh
new file mode 100644
index 000000000..1657a0cde
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/init_spark_dl_torch.sh
@@ -0,0 +1,54 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+set -ex
+
+sudo mkdir -p /spark-rapids-cgroup/devices
+sudo mount -t cgroup -o devices cgroupv1-devices /spark-rapids-cgroup/devices
+sudo chmod a+rwx -R /spark-rapids-cgroup
+
+sudo yum update -y
+sudo yum install -y gcc bzip2-devel libffi-devel tar gzip wget make
+sudo yum install -y mysql-devel --skip-broken
+sudo bash -c "wget https://www.python.org/ftp/python/3.10.9/Python-3.10.9.tgz && \
+tar xzf Python-3.10.9.tgz && cd Python-3.10.9 && \
+./configure --enable-optimizations && make altinstall"
+
+RAPIDS_VERSION=24.12.0
+
+sudo /usr/local/bin/pip3.10 install --upgrade pip
+
+# install cudf
+sudo /usr/local/bin/pip3.10 install --no-cache-dir cudf-cu12 --extra-index-url=https://pypi.nvidia.com --verbose
+
+cat <<EOF > temp_requirements.txt
+datasets==3.*
+transformers
+urllib3<2
+nvidia-pytriton
+torch
+torchvision --extra-index-url https://download.pytorch.org/whl/cu121
+torch-tensorrt
+tensorrt --extra-index-url https://download.pytorch.org/whl/cu121
+sentence_transformers
+sentencepiece
+nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com
+EOF
+
+sudo /usr/local/bin/pip3.10 install --upgrade --force-reinstall -r temp_requirements.txt
+rm temp_requirements.txt
+
+sudo /usr/local/bin/pip3.10 list
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/start_cluster.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/start_cluster.sh
new file mode 100644
index 000000000..7469f0809
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/aws-emr/setup/start_cluster.sh
@@ -0,0 +1,73 @@
+#!/bin/bash -ex
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -o pipefail
+
+cluster_type=${1:-gpu}
+
+# configure arguments
+if [[ -z ${SUBNET_ID} ]]; then
+ echo "Please export SUBNET_ID per README.md"
+ exit 1
+fi
+
+if [[ -z ${BENCHMARK_HOME} ]]; then
+ echo "Please export BENCHMARK_HOME per README.md"
+ exit 1
+fi
+
+if [[ -z ${KEYPAIR} ]]; then
+ echo "Please export KEYPAIR per README.md"
+ exit 1
+fi
+
+cluster_name=spark-rapids-ml-${cluster_type}
+cur_dir=$(pwd)
+
+if [[ ${cluster_type} == "gpu" ]]; then
+ core_type=g5.2xlarge
+ config_json="file://${cur_dir}/../../../notebooks/aws-emr/init-configurations.json"
+ bootstrap_actions="--bootstrap-actions Name='Spark Rapids ML Bootstrap action',Path=s3://${BENCHMARK_HOME}/init-bootstrap-action.sh"
+elif [[ ${cluster_type} == "cpu" ]]; then
+ core_type=m6gd.2xlarge
+ config_json="file://${cur_dir}/cpu-init-configurations.json"
+ bootstrap_actions=""
+else
+ echo "unknown cluster type ${cluster_type}"
+ echo "usage: ./${script_name} cpu|gpu"
+ exit 1
+fi
+
+start_cmd="aws emr create-cluster \
+--name ${cluster_name} \
+--release-label emr-7.3.0 \
+--applications Name=Hadoop Name=Spark \
+--service-role EMR_DefaultRole \
+--log-uri s3://${BENCHMARK_HOME}/logs \
+--ec2-attributes KeyName=$(basename ${KEYPAIR} | sed -e 's/\.pem//g' ),SubnetId=${SUBNET_ID},InstanceProfile=EMR_EC2_DefaultRole \
+--ebs-root-volume-size=32 \
+--instance-groups InstanceGroupType=MASTER,InstanceCount=1,InstanceType=m4.2xlarge \
+ InstanceGroupType=CORE,InstanceCount=3,InstanceType=${core_type} \
+--configurations ${config_json} $bootstrap_actions
+"
+
+CLUSTER_ID=$( eval ${start_cmd} | tee /dev/tty | grep "ClusterId" | grep -o 'j-[0-9A-Z]*')
+aws emr put-auto-termination-policy --cluster-id ${CLUSTER_ID} --auto-termination-policy IdleTimeout=1800
+echo "waiting for cluster ${CLUSTER_ID} to start ... " 1>&2
+
+aws emr wait cluster-running --cluster-id $CLUSTER_ID
+
+echo "cluster started." 1>&2
+echo $CLUSTER_ID
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md
new file mode 100644
index 000000000..2ca135c31
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/README.md
@@ -0,0 +1,41 @@
+# Spark DL Inference on Databricks
+
+## Setup
+
+1. Install the latest [databricks-cli](https://docs.databricks.com/en/dev-tools/cli/tutorial.html) and configure it for your workspace.
+
+2. Specify the path to the notebook and init script (`_torch` or `_tf`), and the destination filepaths on Databricks.
+ As an example for a PyTorch notebook:
+ ```shell
+ export NOTEBOOK_SRC=/path/to/notebook_torch.ipynb
+ export NOTEBOOK_DEST=/Users/someone@example.com/spark-dl/notebook_torch.ipynb
+
+ export INIT_SRC=/path/to/setup/init_spark_dl_torch.sh
+ export INIT_DEST=/Users/someone@example.com/spark-dl/init_spark_dl_torch.sh
+ ```
+
+3. Copy the files to the Databricks Workspace:
+ ```shell
+ databricks workspace import $INIT_DEST --format AUTO --file $INIT_SRC
+ databricks workspace import $NOTEBOOK_DEST --format JUPYTER --file $NOTEBOOK_SRC
+ ```
+
+4. Launch the cluster with the provided script (note that the script specifies **Azure instances** by default; change as needed):
+ ```shell
+ export CLUSTER_NAME=spark-dl-inference-torch
+ cd setup
+ chmod +x start_cluster.sh
+ ./start_cluster.sh
+ ```
+
+ OR, start the cluster from the Databricks UI:
+
+ - Go to `Compute > Create compute` and set the desired cluster settings.
+ - Integration with Triton inference server uses stage-level scheduling (Spark>=3.4.0). Make sure to:
+ - use a cluster with GPU resources
+ - set a value for `spark.executor.cores`
+ - ensure that `spark.executor.resource.gpu.amount` = 1
+ - Under `Advanced Options > Init Scripts`, upload the init script from your workspace.
+   - For Tensorflow notebooks, we recommend setting the environment variable `TF_GPU_ALLOCATOR=cuda_malloc_async` (especially for Huggingface LLM models), which enables the CUDA driver to implicitly release unused memory from the pool.
+
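+   These settings matter because the notebooks use stage-level scheduling to launch one Triton server per executor. A rough sketch of the idea (`executor_cores`, `num_executors`, and `start_triton` are hypothetical placeholders, not the notebooks' exact code):
+   ```python
+   from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests
+
+   # Request every executor core and the whole GPU for each task, so the
+   # server-launch stage runs exactly one task (one server) per executor.
+   task_reqs = TaskResourceRequests().cpus(executor_cores).resource("gpu", 1.0)
+   rp = ResourceProfileBuilder().require(task_reqs).build
+
+   sc = spark.sparkContext
+   sc.parallelize(range(num_executors), num_executors) \
+     .withResources(rp) \
+     .mapPartitions(start_triton) \
+     .collect()
+   ```
+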
+5. Navigate to the notebook in your workspace and attach it to the cluster.
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_tf.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_tf.sh
new file mode 100644
index 000000000..237d20f35
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_tf.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+set -x
+
+# install requirements
+sudo /databricks/python3/bin/pip3 install --upgrade pip
+
+cat <<EOF > temp_requirements.txt
+datasets==3.*
+transformers
+urllib3<2
+nvidia-pytriton
+EOF
+
+sudo /databricks/python3/bin/pip3 install --upgrade --force-reinstall -r temp_requirements.txt
+rm temp_requirements.txt
+
+set +x
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_torch.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_torch.sh
new file mode 100644
index 000000000..669b373b3
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/init_spark_dl_torch.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+set -x
+
+# install requirements
+sudo /databricks/python3/bin/pip3 install --upgrade pip
+
+cat <<EOF > temp_requirements.txt
+datasets==3.*
+transformers
+urllib3<2
+nvidia-pytriton
+torch
+torchvision --extra-index-url https://download.pytorch.org/whl/cu121
+torch-tensorrt
+tensorrt --extra-index-url https://download.pytorch.org/whl/cu121
+sentence_transformers
+sentencepiece
+nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com
+EOF
+
+sudo /databricks/python3/bin/pip3 install --upgrade --force-reinstall -r temp_requirements.txt
+rm temp_requirements.txt
+
+set +x
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/start_cluster.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/start_cluster.sh
new file mode 100755
index 000000000..d9e4dc50a
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/databricks/setup/start_cluster.sh
@@ -0,0 +1,46 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+# configure arguments
+if [[ -z ${INIT_DEST} ]]; then
+ echo "Please make sure INIT_DEST is exported per README.md"
+ exit 1
+fi
+
+if [[ -z ${CLUSTER_NAME} ]]; then
+ echo "Please make sure CLUSTER_NAME is exported per README.md"
+ exit 1
+fi
+
+json_config=$(cat <<EOF
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/README.md b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/README.md
new file mode 100644
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/README.md
+# Spark DL Inference on GCP Dataproc
+
+**Note**: fields in \<brackets\> require user input.
+
+#### Setup GCloud CLI
+
+1. Install the latest [gcloud-cli](https://cloud.google.com/sdk/docs/install) and initialize with `gcloud init`.
+
+2. Configure the following settings:
+ ```shell
+ export PROJECT=
+ export DATAPROC_REGION=
+ export COMPUTE_REGION=
+ export COMPUTE_ZONE=
+
+ gcloud config set project ${PROJECT}
+ gcloud config set dataproc/region ${DATAPROC_REGION}
+ gcloud config set compute/region ${COMPUTE_REGION}
+ gcloud config set compute/zone ${COMPUTE_ZONE}
+ ```
+
+#### Copy files to GCS
+
+3. Create a GCS bucket if you don't already have one:
+ ```shell
+ export GCS_BUCKET=
+
+ gcloud storage buckets create gs://${GCS_BUCKET}
+ ```
+
+4. Specify the local path to the notebook(s) and copy to the GCS bucket.
+ As an example for a torch notebook:
+ ```shell
+ export SPARK_DL_HOME=${GCS_BUCKET}/spark-dl
+
+   gcloud storage cp <path/to/notebook_torch.ipynb> gs://${SPARK_DL_HOME}/notebooks/
+ ```
+ Repeat this step for any notebooks you wish to run. All notebooks under `gs://${SPARK_DL_HOME}/notebooks/` will be copied to the master node during initialization.
+
+#### Start cluster and run
+
+5. Specify the framework to use (torch or tf), which will determine what libraries to install on the cluster. For example:
+ ```shell
+ export FRAMEWORK=torch
+ ```
+   Run the cluster startup script. The script will also retrieve and use the [spark-rapids initialization script](https://github.com/GoogleCloudDataproc/initialization-actions/blob/master/spark-rapids/spark-rapids.sh) to set up GPU resources.
+ ```shell
+ cd setup
+ chmod +x start_cluster.sh
+ ./start_cluster.sh
+ ```
+   By default, the script creates a 4-node GPU cluster named `${USER}-spark-dl-inference-${FRAMEWORK}`.
+
+6. Browse to the Jupyter web UI:
+ - Go to `Dataproc` > `Clusters` > `(Cluster Name)` > `Web Interfaces` > `Jupyter/Lab`
+
+ Or, get the link by running this command (under httpPorts > Jupyter/Lab):
+ ```shell
+ gcloud dataproc clusters describe ${CLUSTER_NAME} --region=${COMPUTE_REGION}
+ ```
+
+7. Open and run the notebook interactively with the **Python 3 kernel**.
+The notebooks can be found under `Local Disk/spark-dl-notebooks` on the master node (folder icon on the top left > Local Disk).
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/init_spark_dl.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/init_spark_dl.sh
new file mode 100644
index 000000000..dfbae71f8
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/init_spark_dl.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+set -euxo pipefail
+
+function get_metadata_attribute() {
+ local -r attribute_name=$1
+ local -r default_value=$2
+ /usr/share/google/get_metadata_value "attributes/${attribute_name}" || echo -n "${default_value}"
+}
+
+SPARK_DL_HOME=$(get_metadata_attribute spark-dl-home UNSET)
+if [[ ${SPARK_DL_HOME} == "UNSET" ]]; then
+ echo "Please set --metadata spark-dl-home"
+ exit 1
+fi
+
+GCS_BUCKET=$(get_metadata_attribute gcs-bucket UNSET)
+if [[ ${GCS_BUCKET} == "UNSET" ]]; then
+ echo "Please set --metadata gcs-bucket"
+ exit 1
+fi
+
+REQUIREMENTS=$(get_metadata_attribute requirements UNSET)
+if [[ ${REQUIREMENTS} == "UNSET" ]]; then
+ echo "Please set --metadata requirements"
+ exit 1
+fi
+
+# mount gcs bucket as fuse
+export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
+echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
+curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
+sudo apt-get update
+sudo apt-get install -y fuse gcsfuse
+sudo mkdir -p /mnt/gcs
+gcsfuse -o allow_other --implicit-dirs ${GCS_BUCKET} /mnt/gcs
+sudo chmod -R 777 /mnt/gcs
+
+# install requirements
+pip install --upgrade pip
+echo "${REQUIREMENTS}" > temp_requirements.txt
+pip install --upgrade --force-reinstall -r temp_requirements.txt
+rm temp_requirements.txt
+
+# copy notebooks to master
+ROLE=$(/usr/share/google/get_metadata_value attributes/dataproc-role)
+if [[ "${ROLE}" == 'Master' ]]; then
+ if gsutil -q stat gs://${SPARK_DL_HOME}/notebooks/**; then
+ mkdir spark-dl-notebooks
+ gcloud storage cp -r gs://${SPARK_DL_HOME}/notebooks/* spark-dl-notebooks
+ else
+ echo "Failed to retrieve notebooks from gs://${SPARK_DL_HOME}/notebooks/"
+ exit 1
+ fi
+fi
+
+sudo chmod -R a+rw /home/
+sudo systemctl daemon-reload
\ No newline at end of file
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/start_cluster.sh b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/start_cluster.sh
new file mode 100755
index 000000000..84880c91e
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/dataproc/setup/start_cluster.sh
@@ -0,0 +1,103 @@
+#!/bin/bash
+# Copyright (c) 2025, NVIDIA CORPORATION.
+
+# configure arguments
+if [[ -z ${GCS_BUCKET} ]]; then
+ echo "Please export GCS_BUCKET per README.md"
+ exit 1
+fi
+
+if [[ -z ${FRAMEWORK} ]]; then
+ echo "Please export FRAMEWORK as 'torch' or 'tf'"
+ exit 1
+fi
+
+if [[ -z ${COMPUTE_REGION} ]]; then
+ COMPUTE_REGION=$(gcloud config get-value compute/region)
+ if [[ -z ${COMPUTE_REGION} ]]; then
+ echo "Please export COMPUTE_REGION per README.md or set it in gcloud config."
+ exit 1
+ fi
+fi
+
+SPARK_DL_HOME=${SPARK_DL_HOME:-${GCS_BUCKET}/spark-dl}
+
+# copy init script to gcs
+gcloud storage cp init_spark_dl.sh gs://${SPARK_DL_HOME}/init/
+INIT_PATH=gs://${SPARK_DL_HOME}/init/init_spark_dl.sh
+
+# retrieve and upload spark-rapids initialization script to gcs
+curl -LO https://raw.githubusercontent.com/GoogleCloudDataproc/initialization-actions/master/spark-rapids/spark-rapids.sh
+# don't enable rapids plugin by default
+sed -i '/spark.plugins=com.nvidia.spark.SQLPlugin/d' spark-rapids.sh
+gcloud storage cp spark-rapids.sh gs://${SPARK_DL_HOME}/init/
+# rm spark-rapids.sh
+
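+# These requirement lists are passed to the init script via the 'requirements' metadata key and pip-installed on every node.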
+COMMON_REQUIREMENTS="numpy
+pandas
+matplotlib
+portalocker
+pyarrow
+pydot
+scikit-learn
+huggingface
+datasets==3.*
+transformers
+urllib3<2
+nvidia-pytriton"
+
+TORCH_REQUIREMENTS="${COMMON_REQUIREMENTS}
+torch
+torchvision --extra-index-url https://download.pytorch.org/whl/cu121
+torch-tensorrt
+tensorrt --extra-index-url https://download.pytorch.org/whl/cu121
+sentence_transformers
+sentencepiece
+nvidia-modelopt[all] --extra-index-url https://pypi.nvidia.com"
+
+TF_REQUIREMENTS="${COMMON_REQUIREMENTS}
+tensorflow[and-cuda]
+tf-keras"
+
+cluster_name=${USER}-spark-dl-inference-${FRAMEWORK}
+if [[ ${FRAMEWORK} == "torch" ]]; then
+ requirements=${TORCH_REQUIREMENTS}
+ echo "========================================================="
+ echo "Starting PyTorch cluster ${cluster_name}"
+ echo "========================================================="
+elif [[ ${FRAMEWORK} == "tf" ]]; then
+ requirements=${TF_REQUIREMENTS}
+ echo "========================================================="
+  echo "Starting TensorFlow cluster ${cluster_name}"
+ echo "========================================================="
+else
+ echo "Please export FRAMEWORK as torch or tf"
+ exit 1
+fi
+
+# start cluster if not already running
+if gcloud dataproc clusters list --region ${COMPUTE_REGION} | grep -q "${cluster_name}"; then
+ echo "Cluster ${cluster_name} already exists."
+else
+ gcloud dataproc clusters create ${cluster_name} \
+    --image-version=2.2-ubuntu22 \
+ --region ${COMPUTE_REGION} \
+ --master-machine-type n1-standard-16 \
+ --num-workers 4 \
+ --worker-min-cpu-platform="Intel Skylake" \
+ --worker-machine-type n1-standard-16 \
+ --master-accelerator type=nvidia-tesla-t4,count=1 \
+ --worker-accelerator type=nvidia-tesla-t4,count=1 \
+ --initialization-actions gs://${SPARK_DL_HOME}/init/spark-rapids.sh,${INIT_PATH} \
+ --metadata gpu-driver-provider="NVIDIA" \
+ --metadata gcs-bucket=${GCS_BUCKET} \
+ --metadata spark-dl-home=${SPARK_DL_HOME} \
+ --metadata requirements="${requirements}" \
+ --worker-local-ssd-interface=NVME \
+ --optional-components=JUPYTER \
+ --bucket ${GCS_BUCKET} \
+ --enable-component-gateway \
+ --max-idle "60m" \
+ --subnet=default \
+ --no-shielded-secure-boot
+fi
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb
index 3105e0661..987d8b3e0 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_tf.ipynb
@@ -5,9 +5,12 @@
"id": "777fc40d",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# PySpark Huggingface Inferencing\n",
- "## Conditional generation with Tensorflow\n",
+ "### Conditional generation with Tensorflow\n",
"\n",
+ "In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation. \n",
"From: https://huggingface.co/docs/transformers/model_doc/t5"
]
},
@@ -16,9 +19,7 @@
"id": "05c79ac4-bf25-421e-b55e-020d6d9e15d5",
"metadata": {},
"source": [
- "### Using TensorFlow\n",
- "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n",
- "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos."
+ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075)"
]
},
{
@@ -31,42 +32,28 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-11 00:16:59.451769: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2024-10-11 00:16:59.459246: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-10-11 00:16:59.467162: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-10-11 00:16:59.469569: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "2024-10-11 00:16:59.475888: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "2025-01-06 20:51:22.130718: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-01-06 20:51:22.137798: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-01-06 20:51:22.145743: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2025-01-06 20:51:22.148175: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2025-01-06 20:51:22.154403: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2024-10-11 00:16:59.818338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ "2025-01-06 20:51:22.582456: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
- "from transformers import AutoTokenizer, TFT5ForConditionalGeneration"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5346a20c",
- "metadata": {},
- "source": [
- "Enabling Huggingface tokenizer parallelism so that it is not automatically disabled with Python parallelism. See [this thread](https://github.com/huggingface/transformers/issues/5486) for more info. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "a1008e27",
- "metadata": {},
- "outputs": [],
- "source": [
+ "from transformers import AutoTokenizer, TFT5ForConditionalGeneration\n",
+ "\n",
+ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
+ "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
"import os\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "275890d7",
"metadata": {},
"outputs": [
@@ -95,7 +82,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "2684fb41-9467-40c0-9d7e-a1cc867c5a3c",
"metadata": {},
"outputs": [
@@ -103,7 +90,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-11 00:17:00.886565: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46024 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n",
+ "2025-01-06 20:51:23.810073: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 43408 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n",
"All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.\n",
"\n",
"All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.\n",
@@ -128,7 +115,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "6eb2dfdb-0ad3-4d0f-81a4-268d92c53759",
"metadata": {},
"outputs": [
@@ -137,26 +124,24 @@
"output_type": "stream",
"text": [
"WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n",
- "I0000 00:00:1728605822.106234 276792 service.cc:146] XLA service 0x7f53a8003630 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
- "I0000 00:00:1728605822.106259 276792 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n",
- "2024-10-11 00:17:02.108842: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
- "2024-10-11 00:17:02.117215: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n",
- "I0000 00:00:1728605822.137920 276792 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"
+ "I0000 00:00:1736196685.027406 2719456 service.cc:146] XLA service 0x788c940027b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:\n",
+ "I0000 00:00:1736196685.027425 2719456 service.cc:154] StreamExecutor device (0): NVIDIA RTX A6000, Compute Capability 8.6\n",
+ "2025-01-06 20:51:25.030311: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.\n",
+ "2025-01-06 20:51:25.039401: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:531] Loaded cuDNN version 8907\n",
+ "I0000 00:00:1736196685.062214 2719456 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n"
]
}
],
"source": [
- "input_ids = tokenizer(input_sequences, \n",
- " padding=\"longest\", \n",
- " max_length=512,\n",
- " truncation=True,\n",
- " return_tensors=\"tf\").input_ids\n",
- "outputs = model.generate(input_ids, max_length=20)"
+ "inputs = tokenizer(input_sequences, \n",
+ " padding=True,\n",
+ " return_tensors=\"tf\")\n",
+ "outputs = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_length=128)"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "720158d4-e0e0-4904-b096-e5aede756afd",
"metadata": {},
"outputs": [
@@ -168,7 +153,7 @@
" 'HuggingFace ist ein Unternehmen']"
]
},
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -177,63 +162,76 @@
"[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "546eabe0",
+ "metadata": {},
+ "source": [
+ "## PySpark"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "68121304-f1df-466e-9347-c9d2b36a9b3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyspark.sql.types import *\n",
+ "from pyspark import SparkConf\n",
+ "from pyspark.sql import SparkSession\n",
+ "from pyspark.sql.functions import pandas_udf, col, struct\n",
+ "from pyspark.ml.functions import predict_batch_udf"
+ ]
+ },
{
"cell_type": "code",
"execution_count": 7,
- "id": "8d4b364b-13cb-48ea-a97a-ccfc9e408075",
+ "id": "2f6db1f0-7d68-4af7-8bd6-c9fa45906c61",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'tf'"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "model.framework"
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
]
},
{
"cell_type": "markdown",
- "id": "546eabe0",
+ "id": "0d636975",
"metadata": {},
"source": [
- "## PySpark"
+ "Check the cluster environment to handle any platform-specific Spark configurations."
]
},
{
"cell_type": "code",
"execution_count": 8,
- "id": "2f6db1f0-7d68-4af7-8bd6-c9fa45906c61",
+ "id": "ca351245",
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
- "from pathlib import Path\n",
- "from datasets import load_dataset"
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
]
},
{
- "cell_type": "code",
- "execution_count": 9,
- "id": "68121304-f1df-466e-9347-c9d2b36a9b3a",
+ "cell_type": "markdown",
+ "id": "d3199f8b",
"metadata": {},
- "outputs": [],
"source": [
- "from pyspark.sql.types import *\n",
- "from pyspark.sql import SparkSession\n",
- "from pyspark import SparkConf\n",
- "import socket"
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"id": "6279a849",
"metadata": {},
"outputs": [
@@ -241,108 +239,134 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/11 00:17:03 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
- "24/10/11 00:17:03 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "25/01/06 20:51:26 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/06 20:51:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
- "24/10/11 00:17:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ "25/01/06 20:51:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
- "hostname = socket.gethostname()\n",
- "\n",
"conf = SparkConf()\n",
+ "\n",
"if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " \n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " source = \"/usr/lib/x86_64-linux-gnu/libstdc++.so.6\"\n",
+ " target = f\"{conda_env}/lib/libstdc++.so.6\"\n",
+ " try:\n",
+ " if os.path.islink(target) or os.path.exists(target):\n",
+ " os.remove(target)\n",
+ " os.symlink(source, target)\n",
+ " except OSError as e:\n",
+ " print(f\"Error creating symlink: {e}\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
"spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
"sc = spark.sparkContext"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "7f311650",
+ "metadata": {},
+ "source": [
+ "Load the IMBD Movie Reviews dataset from Huggingface."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"id": "b8453111-d068-49bb-ab91-8ae3d8bcdb7a",
"metadata": {},
"outputs": [],
"source": [
- "# load IMDB reviews (test) dataset\n",
- "data = load_dataset(\"imdb\", split=\"test\")"
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6fd5b472-47e8-4804-9907-772793fedb2b",
+ "metadata": {},
+ "source": [
+ "### Create PySpark DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "7ad01d4a",
+ "execution_count": 11,
+ "id": "d24d9404-0269-476e-a9dd-1842667c915a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "25000"
+ "StructType([StructField('text', StringType(), True)])"
]
},
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "lines = []\n",
- "for example in data:\n",
- " lines.append([example[\"text\"].split(\".\")[0]])\n",
- "\n",
- "len(lines)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6fd5b472-47e8-4804-9907-772793fedb2b",
- "metadata": {},
- "source": [
- "### Create PySpark DataFrame"
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
]
},
{
"cell_type": "code",
- "execution_count": 13,
- "id": "d24d9404-0269-476e-a9dd-1842667c915a",
+ "execution_count": 12,
+ "id": "c76314b7",
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ },
{
"data": {
"text/plain": [
- "StructType([StructField('lines', StringType(), True)])"
+ "25000"
]
},
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "df = spark.createDataFrame(lines, ['lines']).repartition(8)\n",
- "df.schema"
+ "df.count()"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"id": "4384c762-1f79-4f60-876c-94b1f552e8fb",
"metadata": {},
"outputs": [
@@ -350,16 +374,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 20:51:33 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
},
{
"data": {
"text/plain": [
- "[Row(lines='(Some Spoilers) Dull as dishwater slasher flick that has this deranged homeless man Harry, Darwyn Swalve, out murdering real-estate agent all over the city of L')]"
+ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
]
},
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -378,67 +402,44 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 14,
"id": "e7eec8ec-4126-4890-b957-025809fad67d",
"metadata": {},
- "outputs": [],
- "source": [
- "df.write.mode(\"overwrite\").parquet(\"imdb_test\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "304e1fc8-42a3-47dd-b3c0-47efd5be1040",
- "metadata": {},
- "source": [
- "### Check arrow memory configuration"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "20554ea5-01be-4a30-8607-db5d87786fec",
- "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 20:51:33 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
"source": [
- "if int(spark.conf.get(\"spark.sql.execution.arrow.maxRecordsPerBatch\")) > 512:\n",
- " print(\"Decreasing `spark.sql.execution.arrow.maxRecordsPerBatch` to ensure the vectorized reader won't run out of memory\")\n",
- " spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n",
- "assert len(df.head()) > 0, \"`df` should not be empty\""
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
]
},
{
"cell_type": "markdown",
- "id": "06a4ecab-c9d9-466f-ba49-902ad1fd5488",
+ "id": "078425e1",
"metadata": {},
"source": [
- "## Inference using Spark DL API\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "e7a00479-1347-4de8-8431-faa77f8cdf4c",
- "metadata": {
- "tags": []
- },
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, pandas_udf, struct\n",
- "from pyspark.sql.types import StringType"
+ "#### Load and preprocess DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for translation."
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 15,
"id": "b9a0889a-35b4-493a-8197-1146fc7efd53",
"metadata": {},
"outputs": [],
"source": [
- "# only use first sentence and add prefix for conditional generation\n",
"def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
" @pandas_udf(\"string\")\n",
" def _preprocess(text: pd.Series) -> pd.Series:\n",
@@ -448,7 +449,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 16,
"id": "c483e4d4-9ab1-416f-a766-694e17490fd3",
"metadata": {},
"outputs": [
@@ -456,130 +457,108 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| lines|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| I am a big fan of The ABC Movies of the Week genre|\n",
- "|In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Full House\" and the long-defunc...|\n",
- "|When The Spirits Within was released, all you heard from Final Fantasy fans was how awful the movie was because it di...|\n",
- "| I like to think of myself as a bad movie connoisseur|\n",
- "|This film did well at the box office, and the producers of this mess thought the stars had such good chemistry in thi...|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second one, this incredibly dull, slow, and une...|\n",
- "| I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| I have read all of the reviews for this direct to video movie|\n",
- "|Yes, it was an awful movie, but there was a song near the beginning of the movie, I think, called \"I got a Woody\" or ...|\n",
- "| This was the most uninteresting horror flick I have seen to date|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'The French Connection\", but it...|\n",
- "|Heart of Darkness Movie Review Could a book that is well known for its eloquent wording and complicated concepts ever...|\n",
- "| A bad movie ABOUT a bad movie|\n",
- "|Apart from the fact that this film was made ( I suppose it seemed a good idea at the time considering BOTTOM was so p...|\n",
- "|Watching this movie, you just have to ask: What were they thinking? There are so many noticeably bad parts of this mo...|\n",
- "| OK, lets start with the best|\n",
- "| Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 years|\n",
- "| C|\n",
- "| Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| text|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n",
+ "|There were two things I hated about WASTED : The directing and the script . I know I`m opening my...|\n",
+ "|I'm rather surprised that anybody found this film touching or moving.
The basic premis...|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n",
+ "|This movie has been done before. It is basically a unoriginal combo of \"Napoleon Dynamite\" and \"S...|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get involved in such mindles...|\n",
+ "|There is not one character on this sitcom with any redeeming qualities. They are all self-centere...|\n",
+ "|Tommy Lee Jones was the best Woodroe and no one can play Woodroe F. Call better than he. Not only...|\n",
+ "|My wife rented this movie and then conveniently never got to see it. If I ever want to torture he...|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n",
+ "|you will likely be sorely disappointed by this sequel that's not a sequel.AWIL is a classic.but t...|\n",
+ "|If I was British, I would be embarrassed by this portrayal of incompetence. A protection agent of...|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n",
+ "|This show is like watching someone who is in training to someday host a show. There are some good...|\n",
+ "|Sigh. I'm baffled when I see a short like this get attention and assignments and whatnot. I saw t...|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
- },
- {
- "data": {
- "text/plain": [
- "100"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100)\n",
- "df.show(truncate=120)\n",
- "df.count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "id": "831bc52c-a5c6-4c29-a6da-0566b5167773",
- "metadata": {},
- "outputs": [],
- "source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df1 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to German: \")).select(\"input\").limit(100).cache()"
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df.show(truncate=100)"
]
},
{
- "cell_type": "code",
- "execution_count": 21,
- "id": "46dac59c-5a54-4576-91e0-279c8b375b95",
+ "cell_type": "markdown",
+ "id": "a9f8e538",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "100"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
"source": [
- "df1.count()"
+ "Append a prefix to tell the model to translate English to French:"
]
},
{
"cell_type": "code",
- "execution_count": 22,
- "id": "fef1d846-5852-4762-8527-602f32c0d7cd",
+ "execution_count": 17,
+ "id": "831bc52c-a5c6-4c29-a6da-0566b5167773",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to German: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to German: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to German: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to German: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to German: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to German: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to German: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to German: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to German: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to German: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to German: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to German: OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to German: C|\n",
- "| Translate English to German: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|translate English to French: Doesn't anyone bother to check where this kind of sludge comes from ...|\n",
+ "|translate English to French: There were two things I hated about WASTED : The directing and the s...|\n",
+ "| translate English to French: I'm rather surprised that anybody found this film touching or moving|\n",
+ "|translate English to French: Cultural Vandalism Is the new Hallmark production of Gulliver's Trav...|\n",
+ "|translate English to French: I was at Wrestlemania VI in Toronto as a 10 year old, and the event ...|\n",
+ "| translate English to French: This movie has been done before|\n",
+ "|translate English to French: [ as a new resolution for this year 2005, i decide to write a commen...|\n",
+ "|translate English to French: This movie is over hyped!! I am sad to say that I manage to watch th...|\n",
+ "|translate English to French: This show had a promising start as sort of the opposite of 'Oceans 1...|\n",
+ "|translate English to French: MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors...|\n",
+ "| translate English to French: There is not one character on this sitcom with any redeeming qualities|\n",
+ "| translate English to French: Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
+ "| translate English to French: My wife rented this movie and then conveniently never got to see it|\n",
+ "|translate English to French: This is one of those star-filled over-the-top comedies that could a)...|\n",
+ "|translate English to French: This excruciatingly boring and unfunny movie made me think that Chap...|\n",
+ "|translate English to French: you will likely be sorely disappointed by this sequel that's not a s...|\n",
+ "|translate English to French: If I was British, I would be embarrassed by this portrayal of incomp...|\n",
+ "|translate English to French: One of those movies in which there are no big twists whatsoever and ...|\n",
+ "|translate English to French: This show is like watching someone who is in training to someday hos...|\n",
+ "| translate English to French: Sigh|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
- "df1.show(truncate=120)"
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()\n",
+ "input_df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ec53a65c",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
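+  {
+   "cell_type": "markdown",
+   "id": "b6dcf3b8",
+   "metadata": {},
+   "source": [
+    "As a minimal sketch of this pattern (using a toy uppercase \"model\" in place of T5, for illustration only):\n",
+    "```python\n",
+    "import numpy as np\n",
+    "from pyspark.ml.functions import predict_batch_udf\n",
+    "from pyspark.sql.types import StringType\n",
+    "\n",
+    "def make_predict_fn():\n",
+    "    # runs once per Python worker; load the real model here\n",
+    "    def predict(inputs):\n",
+    "        # inputs is a numpy array batch of strings\n",
+    "        return np.array([s.upper() for s in inputs.tolist()])\n",
+    "    return predict\n",
+    "\n",
+    "upper_udf = predict_batch_udf(make_predict_fn, return_type=StringType(), batch_size=32)\n",
+    "# usage: df.withColumn(\"preds\", upper_udf(col(\"input\")))\n",
+    "```"
+   ]
+  },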
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 18,
"id": "e7ae69d3-70c2-4765-928f-c96a7ba59829",
"metadata": {},
"outputs": [],
@@ -590,6 +569,7 @@
" from transformers import TFT5ForConditionalGeneration, AutoTokenizer\n",
"\n",
" # Enable GPU memory growth\n",
+ " print(\"initializing model\")\n",
" gpus = tf.config.experimental.list_physical_devices('GPU')\n",
" if gpus:\n",
" try:\n",
@@ -602,15 +582,15 @@
" tokenizer = AutoTokenizer.from_pretrained(\"google-t5/t5-small\")\n",
"\n",
" def predict(inputs):\n",
- " flattened = np.squeeze(inputs).tolist() # convert 2d numpy array of string into flattened python list\n",
- " input_ids = tokenizer(flattened, \n",
- " padding=\"longest\", \n",
- " max_length=512,\n",
- " return_tensors=\"tf\").input_ids\n",
- " output_ids = model.generate(input_ids, max_length=20)\n",
- " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " inputs = tokenizer(flattened, \n",
+ " padding=True, \n",
+ " return_tensors=\"tf\")\n",
+ " outputs = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\n",
" print(\"predict: {}\".format(len(flattened)))\n",
- "\n",
" return string_outputs\n",
" \n",
" return predict"
@@ -618,19 +598,19 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 19,
"id": "36684f59-d947-43f8-a2e8-c7a423764e88",
"metadata": {},
"outputs": [],
"source": [
"generate = predict_batch_udf(predict_batch_fn,\n",
" return_type=StringType(),\n",
- " batch_size=10)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 20,
"id": "6a01c855-8fa1-4765-a3a5-2c9dd872df10",
"metadata": {},
"outputs": [
@@ -638,15 +618,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 21:> (0 + 1) / 1]\r"
+ "[Stage 24:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 9.39 ms, sys: 2.14 ms, total: 11.5 ms\n",
- "Wall time: 11.4 s\n"
+ "CPU times: user 13.4 ms, sys: 6.1 ms, total: 19.5 ms\n",
+ "Wall time: 17.8 s\n"
]
},
{
@@ -660,13 +640,13 @@
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "preds = df1.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 21,
"id": "d912d4b0-cd0b-44ea-859a-b23455cc2700",
"metadata": {},
"outputs": [
@@ -674,15 +654,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 23:> (0 + 1) / 1]\r"
+ "[Stage 27:=====================> (3 + 5) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.62 ms, sys: 4.01 ms, total: 7.64 ms\n",
- "Wall time: 8.53 s\n"
+ "CPU times: user 9.07 ms, sys: 4.8 ms, total: 13.9 ms\n",
+ "Wall time: 11.5 s\n"
]
},
{
@@ -695,13 +675,13 @@
],
"source": [
"%%time\n",
- "preds = df1.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 22,
"id": "5fe3d88b-30f7-468f-8db8-1f4118d0f26c",
"metadata": {},
"outputs": [
@@ -709,15 +689,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 25:> (0 + 1) / 1]\r"
+ "[Stage 30:=====================> (3 + 5) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.37 ms, sys: 2.51 ms, total: 7.88 ms\n",
- "Wall time: 8.52 s\n"
+ "CPU times: user 5.94 ms, sys: 4.97 ms, total: 10.9 ms\n",
+ "Wall time: 11.6 s\n"
]
},
{
@@ -730,13 +710,13 @@
],
"source": [
"%%time\n",
- "preds = df1.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 23,
"id": "4ad9b365-4b9a-438e-8fdf-47da55cb1cf4",
"metadata": {},
"outputs": [
@@ -744,37 +724,37 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 27:> (0 + 1) / 1]\r"
+ "[Stage 33:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to German: This is so overly clichéd yo...| Das ist so übertrieben klischeehaft, dass Sie es nach den|\n",
- "|Translate English to German: I am a big fan of The ABC Mo...| Ich bin ein großer Fan von The ABC Movies of the Week|\n",
- "|Translate English to German: In the early 1990's \"Step-by...| Anfang der 1990er Jahre kam \"Step-by-Step\" als müh|\n",
- "|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|\n",
- "|Translate English to German: I like to think of myself as...| Ich halte mich gerne als schlechter Filmliebhaber|\n",
- "|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|\n",
- "|Translate English to German: Following the pleasingly atm...|Nach dem erfreulich stimmungsvollen Original und dem amüsant|\n",
- "|Translate English to German: I like CKY and Viva La Bam, ...| Ich mag CKY und Viva La Bam, also konnte ich mich nicht|\n",
- "|Translate English to German: I have read all of the revie...| Ich habe alle Rezensionen zu diesem direkten Film gelesen.|\n",
- "|Translate English to German: Yes, it was an awful movie, ...| Ja, es war ein schrecklicher Film, aber es gab|\n",
- "|Translate English to German: This was the most uninterest...|Dies war der größte Horrorfilm, den ich bisher gesehen habe.|\n",
- "|Translate English to German: I don't know if this excepti...|Ich weiß nicht, ob dieser außergewöhnlich langweilige Fil...|\n",
- "|Translate English to German: Heart of Darkness Movie Revi...|Herz der Dunkelheit Film Review Kann ein Buch, das für se...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie| Ein schlechter Film ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that thi...| Dieser Film wurde zwar fertiggestellt, aber es schien mir |\n",
- "|Translate English to German: Watching this movie, you jus...|Wenn man diesen Film anschaut, muss man einfach fragen: W...|\n",
- "| Translate English to German: OK, lets start with the best| OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) ...| Anna Christie (Greta Garbo) kehrt nach 15 Jahren zurück,|\n",
- "| Translate English to German: C| C|\n",
- "|Translate English to German: Tom and Jerry are transporti...|Tom und Jerry transportieren Güter über Flugzeug nach Afrika|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\n",
+ "|translate English to French: There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\n",
+ "|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\n",
+ "|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\n",
+ "|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\n",
+ "|translate English to French: This movie has bee...| Ce film a été réalisé avant|\n",
+ "|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\n",
+ "|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\n",
+ "|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\n",
+ "|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\n",
+ "|translate English to French: There is not one c...|Il n'y a pas d'un personnage sur ce sitcom ayan...|\n",
+ "|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\n",
+ "|translate English to French: My wife rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\n",
+ "|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\n",
+ "|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\n",
+ "|translate English to French: you will likely be...|Vous serez probablement très déçu par cette séq...|\n",
+ "|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\n",
+ "|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\n",
+ "|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\n",
+ "| translate English to French: Sigh| Pesée|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
@@ -788,66 +768,22 @@
}
],
"source": [
- "preds.show(truncate=60)"
+ "preds.show(truncate=50)"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 24,
"id": "1eb0c83b-d91b-4f8c-a5e7-c35f55c88108",
"metadata": {},
"outputs": [],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df2 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to French: \")).select(\"input\").limit(100).cache()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "id": "054f94fd-fe79-41e7-b1c7-6124083acc72",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to French: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to French: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to French: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to French: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to French: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to French: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to French: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to French: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to French: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to French: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to French: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to French: OK, lets start with the best|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to French: C|\n",
- "| Translate English to French: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
- }
- ],
- "source": [
- "df2.show(truncate=120)"
+ "input_df2 = df.select(preprocess(col(\"text\"), \"translate English to German: \").alias(\"input\")).cache()"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 25,
"id": "6f6b70f9-188a-402b-9143-78a5788140e4",
"metadata": {},
"outputs": [
@@ -855,15 +791,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 33:> (0 + 1) / 1]\r"
+ "[Stage 36:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.9 ms, sys: 5.97 ms, total: 8.87 ms\n",
- "Wall time: 11.7 s\n"
+ "CPU times: user 9.54 ms, sys: 3.82 ms, total: 13.4 ms\n",
+ "Wall time: 14.5 s\n"
]
},
{
@@ -877,13 +813,13 @@
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "preds = df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 26,
"id": "031a6a5e-7999-4653-b394-19ed478d8c96",
"metadata": {},
"outputs": [
@@ -891,15 +827,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 35:> (0 + 1) / 1]\r"
+ "[Stage 39:==============> (2 + 6) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.41 ms, sys: 1.59 ms, total: 5.99 ms\n",
- "Wall time: 8.23 s\n"
+ "CPU times: user 7.82 ms, sys: 5.49 ms, total: 13.3 ms\n",
+ "Wall time: 11.6 s\n"
]
},
{
@@ -912,13 +848,13 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(\"input\"))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 27,
"id": "229b6515-82f6-4e9c-90f0-a9c3cfb26301",
"metadata": {},
"outputs": [
@@ -926,15 +862,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 37:> (0 + 1) / 1]\r"
+ "[Stage 42:==============> (2 + 6) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.46 ms, sys: 1.17 ms, total: 6.63 ms\n",
- "Wall time: 8.08 s\n"
+ "CPU times: user 5.66 ms, sys: 8.1 ms, total: 13.8 ms\n",
+ "Wall time: 11.7 s\n"
]
},
{
@@ -947,13 +883,13 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(col(\"input\")))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 28,
"id": "8be750ac-fa39-452e-bb4c-c2270bc2f70d",
"metadata": {},
"outputs": [
@@ -961,37 +897,37 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 39:> (0 + 1) / 1]\r"
+ "[Stage 45:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to French: This is so overly clichéd yo...| Vous ne pouvez pas en tirer d'un tel cliché|\n",
- "|Translate English to French: I am a big fan of The ABC Mo...| Je suis un grand fan du genre The ABC Movies of the Week|\n",
- "|Translate English to French: In the early 1990's \"Step-by...| Au début des années 1990, «Step-by-Step» a été une|\n",
- "|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|\n",
- "|Translate English to French: I like to think of myself as...| Je me considère comme un mauvais réalisateur de films|\n",
- "|Translate English to French: This film did well at the bo...| Ce film a bien avancé à la salle de cinéma et|\n",
- "|Translate English to French: Following the pleasingly atm...|Après l'original agréablement atmosphérique et la seconde...|\n",
- "|Translate English to French: I like CKY and Viva La Bam, ...| Je m'aime CKY et Viva La Bam, |\n",
- "|Translate English to French: I have read all of the revie...| J'ai lu tous les commentaires pour ce film direct à vidéo|\n",
- "|Translate English to French: Yes, it was an awful movie, ...| Oui, c'était un film terrible, mais il y avait une chanson|\n",
- "|Translate English to French: This was the most uninterest...| Ce fut le plus inquiétant et le plus inquiétant d'h|\n",
- "|Translate English to French: I don't know if this excepti...|Je ne sais pas si ce film extrêmement tacheté était desti...|\n",
- "|Translate English to French: Heart of Darkness Movie Revi...| Un livre connu pour son éloquence et ses concepts complexes|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie| Un mauvais film ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that thi...|En plus du fait que ce film a été réalisé (je suppose quil s|\n",
- "|Translate English to French: Watching this movie, you jus...|Vous devez simplement vous demander : « Que pense-t-il? » Il|\n",
- "| Translate English to French: OK, lets start with the best| OK, s'il y a lieu de commencer par le meilleur|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) ...|Anna Christie (Greta Garbo) retourne pour voir son père C...|\n",
- "| Translate English to French: C| C|\n",
- "|Translate English to French: Tom and Jerry are transporti...| Tom et Jerry transportent des marchandises par avion en |\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to German: Doesn't anyone bot...|Warum hat man sich nicht angeschaut, woher der ...|\n",
+ "|translate English to German: There were two thi...|Es gab zwei Dinge, die ich hat an WASTED gehass...|\n",
+ "|translate English to German: I'm rather surpris...|Ich bin ziemlich überrascht, dass jemand diesen...|\n",
+ "|translate English to German: Cultural Vandalism...|Kultureller Vandalismus Ist die neue Hallmark-P...|\n",
+ "|translate English to German: I was at Wrestlema...|Ich war als 10 Jahre alt bei Wrestlemania VI in...|\n",
+ "|translate English to German: This movie has bee...| Dieser Film wurde bereits vorgenommen|\n",
+ "|translate English to German: [ as a new resolut...|[ als neue Entschließung für dieses Jahr 2005, ...|\n",
+ "|translate English to German: This movie is over...|Ich hoffe, dass ich die ersten 15 Minuten diese...|\n",
+ "|translate English to German: This show had a pr...|Diese Show hatte einen vielversprechenden Start...|\n",
+ "|translate English to German: MINOR PLOT SPOILER...|br />br />Wie haben sich so talentierte Schausp...|\n",
+ "|translate English to German: There is not one c...|Es gibt keinen Charakter auf dieser Seite mit i...|\n",
+ "|translate English to German: Tommy Lee Jones wa...|Tommy Lee Jones war der beste Woodroe und niema...|\n",
+ "|translate English to German: My wife rented thi...|Meine Frau hat diesen Film vermietet und dann b...|\n",
+ "|translate English to German: This is one of tho...|Dies ist eines der Sterne-gefüllten über-the-to...|\n",
+ "|translate English to German: This excruciatingl...|Dieser schreckliche langweilige und unfunnelnde...|\n",
+ "|translate English to German: you will likely be...|Sie werden wahrscheinlich ernsthaft enttäuscht ...|\n",
+ "|translate English to German: If I was British, ...|Wenn ich Britisch wäre, wäre ich beschämt über ...|\n",
+ "|translate English to German: One of those movie...|Einer der Filme, in denen es keine großen Drehu...|\n",
+ "|translate English to German: This show is like ...|Diese Show ist wie ein jemanden, der in Ausbild...|\n",
+ "| translate English to German: Sigh| Segnen|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
@@ -1005,416 +941,297 @@
}
],
"source": [
- "preds.show(truncate=60)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "bcabb2a8-3880-46ec-8e01-5a10f71fe83d",
- "metadata": {},
- "source": [
- "### Using Triton Inference Server\n",
- "\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment. "
+ "preds.show(truncate=50)"
]
},
{
"cell_type": "markdown",
- "id": "5d98fa52-7665-49bf-865a-feec86effe23",
+ "id": "f5803188",
"metadata": {},
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n huggingface-tf -c conda-forge python=3.10.0\n",
- "conda activate huggingface-tf\n",
+ "## Using Triton Inference Server\n",
+    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated model serving solution for deep learning. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
"\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 tensorflow[and-cuda] tf-keras transformers conda-pack \n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+    "- Wrap the Triton inference function in a `predict_batch_udf` to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
"\n",
- "conda-pack # huggingface-tf.tar.gz\n",
- "```"
+ ""
]
},
{
"cell_type": "code",
- "execution_count": 35,
- "id": "b858cf85-82e6-41ef-905b-d8c5d6fea492",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 29,
+ "id": "6d09f972",
+ "metadata": {},
"outputs": [],
"source": [
- "import os"
+ "from functools import partial"
]
},
{
"cell_type": "code",
- "execution_count": 36,
- "id": "05ce7c77-d562-45e8-89bb-cd656aba5a5f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 30,
+ "id": "afd00b7e-8150-4c95-a2e4-037e9c90f92a",
+ "metadata": {},
"outputs": [],
"source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/hf_generation_tf models\n",
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import tensorflow as tf\n",
+ " from transformers import TFT5ForConditionalGeneration, AutoTokenizer\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
"\n",
- "# add custom execution environment\n",
- "cp huggingface-tf.tar.gz models"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a552865c-5dad-4f25-8834-f41e253ac2f6",
- "metadata": {
- "tags": []
- },
- "source": [
- "#### Start Triton Server on each executor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "id": "afd00b7e-8150-4c95-a2e4-037e9c90f92a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 37,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "huggingface_cache_dir = \"{}/.cache/huggingface\".format(os.path.expanduser('~'))\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
+ " print(f\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\")\n",
"\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
+    "    # Enable GPU memory growth so TensorFlow doesn't reserve all GPU memory up front\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
" try:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " environment=[\n",
- " \"TRANSFORMERS_CACHE=/cache\"\n",
- " ],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"1G\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " huggingface_cache_dir: {\"bind\": \"/cache\", \"mode\": \"rw\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- " except Exception as e:\n",
- " print(\">>>> failed to start triton: {}\".format(e))\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ " \n",
+ " tokenizer = AutoTokenizer.from_pretrained(\"google-t5/t5-small\")\n",
+ " model = TFT5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n",
+ "\n",
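+    "    # @batch aggregates concurrent requests into batched numpy arrays keyed by input name\n",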
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"]).tolist()\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " inputs = tokenizer(decoded_sentences,\n",
+ " padding=True,\n",
+ " return_tensors=\"tf\")\n",
+ " output_ids = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\n",
+ " return {\n",
+ " \"translations\": outputs,\n",
+ " }\n",
+ "\n",
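+    "    # Bind the model under a unique workspace, then serve until a SIGTERM arrives\n",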
+    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%H_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"ConditionalGeneration\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"translations\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
"\n",
- " return [True]\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
"\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
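+    "    # Poll until the model is ready (up to 10 attempts with a 6s timeout each)\n",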
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for info.\"\n",
+ " raise TimeoutError(emsg)"
]
},
{
"cell_type": "markdown",
- "id": "528d2df6-49fc-4be7-a534-a087dfe31c84",
+ "id": "527da1b0",
"metadata": {},
"source": [
- "#### Run inference"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 38,
- "id": "1a997c33-5202-466d-8304-b8c30f32978f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "import pandas as pd\n",
- "from functools import partial\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, pandas_udf, struct\n",
- "from pyspark.sql.types import StringType"
+ "#### Start Triton servers\n",
+ "\n",
+    "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to assign each task its own GPU. "
]
},
{
"cell_type": "code",
- "execution_count": 39,
- "id": "9dea1875-6b95-4fc0-926d-a625a441b33d",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 31,
+ "id": "388b6130",
+ "metadata": {},
"outputs": [],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100).cache()"
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
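+    "    # With the RAPIDS SQL plugin active, request all executor cores so the server task runs alone;\n",
+    "    # otherwise request just over half the cores, which still limits Spark to one such task per executor\n",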
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+    "    print(f\"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
]
},
{
- "cell_type": "code",
- "execution_count": 40,
- "id": "5d6c54e7-534d-406f-b8e6-fd592efd0ab2",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "id": "96b35b50",
+ "metadata": {},
"source": [
- "# only use first sentence and add prefix for conditional generation\n",
- "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
- " @pandas_udf(\"string\")\n",
- " def _preprocess(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
- " return _preprocess(text)"
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
- "execution_count": 41,
- "id": "dc1bbbe3-4232-49e5-80f6-99976524b73b",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 32,
+ "id": "7c4855ca",
+ "metadata": {},
"outputs": [],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df1 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to German: \")).select(\"input\").limit(100)"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
]
},
{
"cell_type": "code",
- "execution_count": 42,
- "id": "5d10c61c-6102-4d19-8dd6-0c7b5b65343e",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 33,
+ "id": "3d522f30",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to German: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to German: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to German: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to German: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to German: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to German: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to German: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to German: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to German: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to German: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to German: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to German: OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to German: C|\n",
- "| Translate English to German: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
+      "Requesting stage-level resources: (cores=5, gpu=1.0)\n"
]
}
],
"source": [
- "df1.show(truncate=120)"
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
- "cell_type": "code",
- "execution_count": 43,
- "id": "2e0907da-a5d9-4c3b-9db4-ce5e70ca9bb4",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "id": "f9cc80bf",
+ "metadata": {},
"source": [
- "def triton_fn(triton_uri, model_name):\n",
- " import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
+    "Triton requires three ports: one each for HTTP requests, gRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 44,
- "id": "9308bdd7-6f67-484d-8b51-dd1e1b2960ba",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 34,
+ "id": "ead2f799",
+ "metadata": {},
"outputs": [],
"source": [
- "generate = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"hf_generation_tf\"),\n",
- " return_type=StringType(),\n",
- " input_tensor_shapes=[[1]],\n",
- " batch_size=100)"
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
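+    "    # Scan upward from port 7000, collecting the first three ports not already in use\n",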
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
]
},
{
"cell_type": "code",
- "execution_count": 45,
- "id": "38484ffd-370d-492b-8ca4-9eff9f242a9f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 35,
+ "id": "3487a85d",
+ "metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 45:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.88 ms, sys: 3.96 ms, total: 9.84 ms\n",
- "Wall time: 2.66 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "Using ports [7000, 7001, 7002]\n"
]
}
],
"source": [
- "%%time\n",
- "# first pass caches model/fn\n",
- "preds = df1.withColumn(\"preds\", generate(struct(\"input\")))\n",
- "results = preds.collect()"
+ "model_name = \"ConditionalGeneration\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
]
},
{
"cell_type": "code",
- "execution_count": 46,
- "id": "ebcb6699-3ac2-4529-ab0f-fab0a5e792da",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 36,
+ "id": "c605ab40",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 47:> (0 + 1) / 1]\r"
+ "[Stage 46:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.82 ms, sys: 1.05 ms, total: 3.87 ms\n",
- "Wall time: 1.03 s\n"
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2757193\n",
+ "}\n"
]
},
{
@@ -1426,204 +1243,145 @@
}
],
"source": [
- "%%time\n",
- "preds = df1.withColumn(\"preds\", generate(\"input\"))\n",
- "results = preds.collect()"
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f284eb3",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
]
},
{
"cell_type": "code",
- "execution_count": 47,
- "id": "e2ed18ad-d00b-472c-b2c3-047932f2105d",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 49:> (0 + 1) / 1]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 1.55 ms, sys: 2.49 ms, total: 4.03 ms\n",
- "Wall time: 967 ms\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
+ "execution_count": 37,
+ "id": "404c5091",
+ "metadata": {},
+ "outputs": [],
"source": [
- "%%time\n",
- "preds = df1.withColumn(\"preds\", generate(col(\"input\")))\n",
- "results = preds.collect()"
+ "url = f\"http://localhost:{ports[0]}\""
]
},
{
"cell_type": "code",
- "execution_count": 48,
- "id": "0cd64a1c-beb8-47d5-ac6f-e8525bb61176",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 51:> (0 + 1) / 1]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to German: This is so overly clichéd yo...| Das ist so übertrieben klischeehaft, dass Sie es nach den|\n",
- "|Translate English to German: I am a big fan of The ABC Mo...| Ich bin ein großer Fan von The ABC Movies of the Week|\n",
- "|Translate English to German: In the early 1990's \"Step-by...| Anfang der 1990er Jahre kam \"Step-by-Step\" als müh|\n",
- "|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|\n",
- "|Translate English to German: I like to think of myself as...| Ich halte mich gerne als schlechter Filmliebhaber|\n",
- "|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|\n",
- "|Translate English to German: Following the pleasingly atm...|Nach dem erfreulich stimmungsvollen Original und dem amüsant|\n",
- "|Translate English to German: I like CKY and Viva La Bam, ...| Ich mag CKY und Viva La Bam, also konnte ich mich nicht|\n",
- "|Translate English to German: I have read all of the revie...| Ich habe alle Rezensionen zu diesem direkten Film gelesen.|\n",
- "|Translate English to German: Yes, it was an awful movie, ...| Ja, es war ein schrecklicher Film, aber es gab|\n",
- "|Translate English to German: This was the most uninterest...|Dies war der größte Horrorfilm, den ich bisher gesehen habe.|\n",
- "|Translate English to German: I don't know if this excepti...|Ich weiß nicht, ob dieser außergewöhnlich langweilige Fil...|\n",
- "|Translate English to German: Heart of Darkness Movie Revi...|Herz der Dunkelheit Film Review Kann ein Buch, das für se...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie| Ein schlechter Film ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that thi...| Dieser Film wurde zwar fertiggestellt, aber es schien mir |\n",
- "|Translate English to German: Watching this movie, you jus...|Wenn man diesen Film anschaut, muss man einfach fragen: W...|\n",
- "| Translate English to German: OK, lets start with the best| OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) ...| Anna Christie (Greta Garbo) kehrt nach 15 Jahren zurück,|\n",
- "| Translate English to German: C| C|\n",
- "|Translate English to German: Tom and Jerry are transporti...|Tom und Jerry transportieren Güter über Flugzeug nach Afrika|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
+ "execution_count": 38,
+ "id": "aff88b3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
+ "\n",
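+    "    # Each call opens a client to the local Triton server and sends one batch for inference\n",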
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist() \n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " result_data = np.squeeze(result_data[\"translations\"], -1)\n",
+ " return result_data\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a85e2ceb",
+ "metadata": {},
"source": [
- "preds.show(truncate=60)"
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 49,
- "id": "af70fed8-0f2b-4ea7-841c-476afdf9b1c0",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 39,
+ "id": "2fa3664e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
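+    "# Only use the first sentence and add a prefix for conditional generation\n",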
+ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
+ " @pandas_udf(\"string\")\n",
+ " def _preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
+ " return _preprocess(text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "5d6c54e7-534d-406f-b8e6-fd592efd0ab2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "dc1bbbe3-4232-49e5-80f6-99976524b73b",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/11 00:18:52 WARN CacheManager: Asked to cache already cached data.\n"
+ "25/01/06 20:53:05 WARN CacheManager: Asked to cache already cached data.\n"
]
}
],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df2 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to French: \")).select(\"input\").limit(100).cache()"
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e71f07d4",
+ "metadata": {},
+ "source": [
+    "#### Run inference"
]
},
{
"cell_type": "code",
- "execution_count": 50,
- "id": "ef075e10-e22c-4236-9e0b-cb47cf2d3d06",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to French: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to French: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to French: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to French: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to French: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to French: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to French: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to French: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to French: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to French: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to French: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to French: OK, lets start with the best|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to French: C|\n",
- "| Translate English to French: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
- }
- ],
+ "execution_count": 42,
+ "id": "5d10c61c-6102-4d19-8dd6-0c7b5b65343e",
+ "metadata": {},
+ "outputs": [],
"source": [
- "df2.show(truncate=120)"
+ "generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
+ " return_type=StringType(),\n",
+ " input_tensor_shapes=[[1]],\n",
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 51,
- "id": "2e7e4af8-b815-4375-b851-8368309ee8e1",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 43,
+ "id": "2e0907da-a5d9-4c3b-9db4-ce5e70ca9bb4",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 55:> (0 + 1) / 1]\r"
+ "[Stage 50:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.91 ms, sys: 1.34 ms, total: 5.25 ms\n",
- "Wall time: 1.27 s\n"
+ "CPU times: user 12.8 ms, sys: 9.46 ms, total: 22.2 ms\n",
+ "Wall time: 30.6 s\n"
]
},
{
@@ -1636,33 +1394,30 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "# first pass caches model/fn\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 52,
- "id": "7b0aefb0-a96b-4791-a23c-1ce9b24eb20c",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 44,
+ "id": "9308bdd7-6f67-484d-8b51-dd1e1b2960ba",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 57:> (0 + 1) / 1]\r"
+ "[Stage 53:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.31 ms, sys: 0 ns, total: 4.31 ms\n",
- "Wall time: 1 s\n"
+ "CPU times: user 7.36 ms, sys: 8.88 ms, total: 16.2 ms\n",
+ "Wall time: 20.9 s\n"
]
},
{
@@ -1675,33 +1430,29 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 53,
- "id": "1214b75b-a373-4579-b4c6-0cb8627da776",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 45,
+ "id": "38484ffd-370d-492b-8ca4-9eff9f242a9f",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 59:> (0 + 1) / 1]\r"
+ "[Stage 56:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.84 ms, sys: 1.31 ms, total: 4.15 ms\n",
- "Wall time: 990 ms\n"
+ "CPU times: user 10.8 ms, sys: 5.61 ms, total: 16.4 ms\n",
+ "Wall time: 21.9 s\n"
]
},
{
@@ -1714,55 +1465,51 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 54,
- "id": "c9dbd21f-9e37-4221-b765-80ba8c80b884",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 46,
+ "id": "ebcb6699-3ac2-4529-ab0f-fab0a5e792da",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 61:> (0 + 1) / 1]\r"
+ "[Stage 59:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to French: This is so overly clichéd yo...| Vous ne pouvez pas en tirer d'un tel cliché|\n",
- "|Translate English to French: I am a big fan of The ABC Mo...| Je suis un grand fan du genre The ABC Movies of the Week|\n",
- "|Translate English to French: In the early 1990's \"Step-by...| Au début des années 1990, «Step-by-Step» a été une|\n",
- "|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|\n",
- "|Translate English to French: I like to think of myself as...| Je me considère comme un mauvais réalisateur de films|\n",
- "|Translate English to French: This film did well at the bo...| Ce film a bien avancé à la salle de cinéma et|\n",
- "|Translate English to French: Following the pleasingly atm...|Après l'original agréablement atmosphérique et la seconde...|\n",
- "|Translate English to French: I like CKY and Viva La Bam, ...| Je m'aime CKY et Viva La Bam, |\n",
- "|Translate English to French: I have read all of the revie...| J'ai lu tous les commentaires pour ce film direct à vidéo|\n",
- "|Translate English to French: Yes, it was an awful movie, ...| Oui, c'était un film terrible, mais il y avait une chanson|\n",
- "|Translate English to French: This was the most uninterest...| Ce fut le plus inquiétant et le plus inquiétant d'h|\n",
- "|Translate English to French: I don't know if this excepti...|Je ne sais pas si ce film extrêmement tacheté était desti...|\n",
- "|Translate English to French: Heart of Darkness Movie Revi...| Un livre connu pour son éloquence et ses concepts complexes|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie| Un mauvais film ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that thi...|En plus du fait que ce film a été réalisé (je suppose quil s|\n",
- "|Translate English to French: Watching this movie, you jus...|Vous devez simplement vous demander : « Que pense-t-il? » Il|\n",
- "| Translate English to French: OK, lets start with the best| OK, s'il y a lieu de commencer par le meilleur|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) ...|Anna Christie (Greta Garbo) retourne pour voir son père C...|\n",
- "| Translate English to French: C| C|\n",
- "|Translate English to French: Tom and Jerry are transporti...| Tom et Jerry transportent des marchandises par avion en |\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: Doesn't anyone bot...|Ne s'ennuie-t-il pas de vérifier où viennent ce...|\n",
+ "|translate English to French: There were two thi...|Il y avait deux choses que j'ai hâte de voir : ...|\n",
+ "|translate English to French: I'm rather surpris...|Je suis plutôt surpris que quelqu'un ait trouvé...|\n",
+ "|translate English to French: Cultural Vandalism...|Vandalisme culturel La nouvelle production Hall...|\n",
+ "|translate English to French: I was at Wrestlema...|J'étais à Wrestlemania VI à Toronto en 10 ans, ...|\n",
+ "|translate English to French: This movie has bee...| Ce film a été réalisé avant|\n",
+ "|translate English to French: [ as a new resolut...|[ en tant que nouvelle résolution pour cette an...|\n",
+ "|translate English to French: This movie is over...|Je suis triste de dire que je parviens à regard...|\n",
+ "|translate English to French: This show had a pr...|Ce spectacle a eu un début prometteur en l'espè...|\n",
+ "|translate English to French: MINOR PLOT SPOILER...|br />br /> Comment ces acteurs talentueux ont-i...|\n",
+ "|translate English to French: There is not one c...|Il n'y a pas d'un personnage sur ce sitcom ayan...|\n",
+ "|translate English to French: Tommy Lee Jones wa...|Tommy Lee Jones était le meilleur Woodroe et pe...|\n",
+ "|translate English to French: My wife rented thi...|Ma femme a loué ce film et n'a jamais pu le voi...|\n",
+ "|translate English to French: This is one of tho...|C’est l’une des comédies en étoiles à l’étoile ...|\n",
+ "|translate English to French: This excruciatingl...|Ce film excruciant ennuyant et infaillible m’a ...|\n",
+ "|translate English to French: you will likely be...|Vous serez probablement très déçu par cette séq...|\n",
+ "|translate English to French: If I was British, ...|Si j'étais britannique, je seraitis embarrassé ...|\n",
+ "|translate English to French: One of those movie...|Un des films dans lesquels il n'y a pas de gros...|\n",
+ "|translate English to French: This show is like ...|Ce spectacle ressemble à l'observation d'une pe...|\n",
+ "| translate English to French: Sigh| Pesée|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
@@ -1776,7 +1523,7 @@
}
],
"source": [
- "preds.show(truncate=60)"
+ "preds.show(truncate=50)"
]
},
{
@@ -1786,19 +1533,22 @@
"tags": []
},
"source": [
- "#### Stop Triton Server on each executor"
+ "#### Shut down server on each executor"
]
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 47,
"id": "425d3b28-7705-45ba-8a18-ad34fc895219",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+      "Requesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -1812,31 +1562,39 @@
"[True]"
]
},
- "execution_count": 55,
+ "execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
+ " \n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
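+    "    # Retry SIGTERM until the process exits; os.kill raises OSError once the pid is gone\n",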
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": null,
"id": "2dec80ca-7a7c-46a9-97c0-7afb1572f5b9",
"metadata": {},
"outputs": [],
@@ -1855,7 +1613,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "spark-dl-tf",
+ "display_name": "test-tf",
"language": "python",
"name": "python3"
},
@@ -1869,7 +1627,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.11.9"
+ "version": "3.11.11"
}
},
"nbformat": 4,
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch copy.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch copy.ipynb
new file mode 100644
index 000000000..c872dcd93
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch copy.ipynb
@@ -0,0 +1,1951 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "8f6659b4-88da-4207-8d32-2674da5383a0",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "\n",
+ "\n",
+ "# PySpark DL Inference\n",
+ "### Conditional generation with Huggingface\n",
+ "\n",
+ "In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation. \n",
+ "From: https://huggingface.co/docs/transformers/model_doc/t5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+ "\n",
+    "# Manually enable Huggingface tokenizer parallelism so it isn't disabled by PySpark's process parallelism.\n",
+ "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
+ "import os\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
+ ]
+ }
+ ],
+ "source": [
+ "tokenizer = T5Tokenizer.from_pretrained(\"google-t5/t5-small\")\n",
+ "model = T5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n",
+ "\n",
+ "task_prefix = \"translate English to German: \"\n",
+ "\n",
+ "lines = [\n",
+ " \"The house is wonderful\",\n",
+ " \"Welcome to NYC\",\n",
+ " \"HuggingFace is a company\"\n",
+ "]\n",
+ "\n",
+ "input_sequences = [task_prefix + l for l in lines]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "inputs = tokenizer(input_sequences,\n",
+ " padding=True, \n",
+ " return_tensors=\"pt\")\n",
+ "\n",
+ "outputs = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_length=128)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['Das Haus ist wunderbar',\n",
+ " 'Willkommen in NYC',\n",
+ " 'HuggingFace ist ein Unternehmen']"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "[tokenizer.decode(o, skip_special_tokens=True) for o in outputs]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## PySpark"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1b8dae4a-3bfc-4430-b28a-7350db5efed4",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from pyspark.sql.types import *\n",
+ "from pyspark import SparkConf\n",
+ "from pyspark.sql import SparkSession\n",
+ "from pyspark.sql.functions import pandas_udf, col, struct\n",
+ "from pyspark.ml.functions import predict_batch_udf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a93a1424-e483-4d37-a719-32fabee3f285",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Check the cluster environment to handle any platform-specific Spark configurations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/16 17:03:26 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/16 17:03:26 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+ "25/01/16 17:03:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ ]
+ }
+ ],
+ "source": [
+ "conf = SparkConf()\n",
+ "\n",
+ "if 'spark' not in globals():\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
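+    "    # 8 concurrent tasks per executor, each granted 1/8 of the GPU\n",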
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
+ "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
+ "sc = spark.sparkContext"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f08c37a5-fb0c-45f6-8630-d2af67831641",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+    "Load the IMDB Movie Reviews dataset from Huggingface."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f0ec30c9-365a-43c5-9c53-3497400ee548",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1e4269da-d2b3-46a5-9309-38a1ba825a47",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Create PySpark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "30dab34d-8e4b-4f30-b7c2-3dff49da018b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "StructType([StructField('text', StringType(), True)])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "55c33cc0-5dfb-449c-ae79-80972fb04405",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "25000"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "efd6d6d9-1c2c-4131-8df4-a3ef75c3fc57",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/16 17:03:33 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+       "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.<br /><br />The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.<br /><br />The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.<br /><br />I really got nothing much left to say except, give us back CKY2K, cause Bam suck..<br /><br />I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.take(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "65a5b258-1634-441e-8b36-29777e54592d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/16 17:03:33 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "89b909f4-5732-428b-ad61-9a6c5cf94df2",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Load and preprocess DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for translation."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "eb7e53d6-bbd0-48d2-a3be-36847275e2a9",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
+ " @pandas_udf(\"string\")\n",
+ " def _preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
+ " return _preprocess(text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "97eee1a4-9dc4-43b0-9578-6d7f8ff338bd",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| text|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery...|\n",
+ "|Awkward disaster mishmash has a team of scavengers coming across the overturned S.S. Poseidon, ho...|\n",
+ "|Here is a fantastic concept for a film - a series of meteors crash into a small town and the resu...|\n",
+ "|I walked out of the cinema having suffered this film after 30 mins. I left two friends pinned in ...|\n",
+ "|A wildly uneven film where the major problem is the uneasy mix of comedy and thriller. To me, the...|\n",
+ "|Leonard Rossiter and Frances de la Tour carry this film, not without a struggle, as the script wa...|\n",
+ "|A good cast... A good idea but turns out it is flawed as hypnosis is not allowed as evidence in c...|\n",
+ "|Yet again, I appear to be the only person on planet Earth who is capable of criticizing Japanese ...|\n",
+ "|As a serious horror fan, I get that certain marketing ploys are used to sell movies, especially t...|\n",
+ "|Upon writing this review I have difficulty trying to think of what to write about. Nothing much h...|\n",
+ "|Simply awful. I'm including a spoiler warning here only because of including a coupla jokes from ...|\n",
+ "|I am a fan of Ed Harris' work and I really had high expectations about this film. Having so good ...|\n",
+ "|Well...I like Patricia Kaas. She is a beautiful lady and an extremely gifted and versatile singer...|\n",
+ "|This is a new approach to comedy. It isn't funny.
The joke is that this, in and of its...|\n",
+ "|It's been mentioned by others the inane dialogue in this series and I agree.
If Mom an...|\n",
+ "|One of the most boring movies I've ever had to sit through, it's completely formulaic. Just a coo...|\n",
+ "|This movie was playing on Lifetime Movie Network last month and I decided to check it out. I watc...|\n",
+ "|1983's \"Frightmare\" is an odd little film. The director seems to be trying to combine the atmosph...|\n",
+ "|'Felony' is a B-movie. No doubt about it.
Of course, if you take a look at the cast li...|\n",
+ "|This movie defines the word \"confused\". All the actors stay true to the script. More's the pity, ...|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)\n",
+ "df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Append a prefix to tell the model to translate English to French:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "fa14304d-b409-4d07-99ef-9da7c7c76158",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|translate English to French: The only reason I'm even giving this movie a 4 is because it was mad...|\n",
+ "|translate English to French: Awkward disaster mishmash has a team of scavengers coming across the...|\n",
+ "|translate English to French: Here is a fantastic concept for a film - a series of meteors crash i...|\n",
+ "| translate English to French: I walked out of the cinema having suffered this film after 30 mins|\n",
+ "|translate English to French: A wildly uneven film where the major problem is the uneasy mix of co...|\n",
+ "|translate English to French: Leonard Rossiter and Frances de la Tour carry this film, not without...|\n",
+ "| translate English to French: A good cast|\n",
+ "|translate English to French: Yet again, I appear to be the only person on planet Earth who is cap...|\n",
+ "|translate English to French: As a serious horror fan, I get that certain marketing ploys are used...|\n",
+ "|translate English to French: Upon writing this review I have difficulty trying to think of what t...|\n",
+ "| translate English to French: Simply awful|\n",
+ "|translate English to French: I am a fan of Ed Harris' work and I really had high expectations abo...|\n",
+ "| translate English to French: Well|\n",
+ "| translate English to French: This is a new approach to comedy|\n",
+ "|translate English to French: It's been mentioned by others the inane dialogue in this series and ...|\n",
+ "|translate English to French: One of the most boring movies I've ever had to sit through, it's com...|\n",
+ "|translate English to French: This movie was playing on Lifetime Movie Network last month and I de...|\n",
+ "| translate English to French: 1983's \"Frightmare\" is an odd little film|\n",
+ "| translate English to French: 'Felony' is a B-movie|\n",
+ "| translate English to French: This movie defines the word \"confused\"|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()\n",
+ "input_df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "bc9cbdd2-1ca6-48e4-a549-792b3726525b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "adb81177-442d-42ab-b86d-d8792201b4c8",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def predict_batch_fn():\n",
+ " import numpy as np\n",
+ " import torch\n",
+ " from transformers import T5ForConditionalGeneration, T5Tokenizer\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " print(f\"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.\")\n",
+ " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\").to(device)\n",
+ " tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
+ "\n",
+ " def predict(inputs):\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " inputs = tokenizer(flattened, \n",
+ " padding=True,\n",
+ " return_tensors=\"pt\").to(device)\n",
+ " outputs = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\n",
+ " print(\"predict: {}\".format(len(flattened)))\n",
+ " return string_outputs\n",
+ " \n",
+ " return predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "20aab3a1-2284-4c07-9ce1-a20cf54d88f3",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "generate = predict_batch_udf(predict_batch_fn,\n",
+ " return_type=StringType(),\n",
+ " batch_size=32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a8d6f48e-09e7-4fc7-9d2f-1b68bc2976a7",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 24:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5.13 ms, sys: 3.73 ms, total: 8.85 ms\n",
+ "Wall time: 6.94 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# first pass caches model/fn\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "abe2271d-0077-48f6-98b1-93524dd86447",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 27:====================================> (5 + 3) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 4.55 ms, sys: 2.75 ms, total: 7.3 ms\n",
+ "Wall time: 3.93 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "77623711-a742-4262-8839-16fc3ddd1af7",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 30:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 4.43 ms, sys: 1.82 ms, total: 6.25 ms\n",
+ "Wall time: 3.91 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f339c654-52fd-4992-b054-188dfb260e5d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n",
+ "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n",
+ "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n",
+ "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n",
+ "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n",
+ "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n",
+ "| translate English to French: A good cast| Une bonne étoile|\n",
+ "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n",
+ "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n",
+ "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n",
+ "| translate English to French: Simply awful| Tout simplement terrible|\n",
+ "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n",
+ "| translate English to French: Well| Eh bien|\n",
+ "|translate English to French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n",
+ "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n",
+ "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n",
+ "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n",
+ "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n",
+ "|translate English to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n",
+ "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "preds.show(truncate=50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's try English to German:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_df2 = df.select(preprocess(col(\"text\"), \"translate English to German: \").alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 36:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 6.15 ms, sys: 1.36 ms, total: 7.51 ms\n",
+ "Wall time: 3.99 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# first pass caches model/fn\n",
+ "preds = input_df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "result = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 39:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 7.83 ms, sys: 0 ns, total: 7.83 ms\n",
+ "Wall time: 3.69 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df2.withColumn(\"preds\", generate(\"input\"))\n",
+ "result = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 42:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 3.59 ms, sys: 3.38 ms, total: 6.96 ms\n",
+ "Wall time: 3.68 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df2.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "result = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to German: The only reason I'...|Der einzige Grund, warum ich sogar diesen Film ...|\n",
+ "|translate English to German: Awkward disaster m...|Awkward-Katastrophenmischmash hat ein Team von ...|\n",
+ "|translate English to German: Here is a fantasti...|Hier ist ein fantastisches Konzept für einen Fi...|\n",
+ "|translate English to German: I walked out of th...|Ich ging aus dem Kino, nachdem ich diesen Film ...|\n",
+ "|translate English to German: A wildly uneven fi...|Ein völlig ungleicher Film, in dem das Hauptpro...|\n",
+ "|translate English to German: Leonard Rossiter a...|Leonard Rossiter und Frances de la Tour tragen ...|\n",
+ "| translate English to German: A good cast| Gutes Casting|\n",
+ "|translate English to German: Yet again, I appea...|Ich scheine wieder einmal die einzige Person au...|\n",
+ "|translate English to German: As a serious horro...|Als ernsthafter Horrorfan erhalte ich, dass bes...|\n",
+ "|translate English to German: Upon writing this ...|Ich habe Schwierigkeiten, mich an die Regeln zu...|\n",
+ "| translate English to German: Simply awful| Einfach schrecklich|\n",
+ "|translate English to German: I am a fan of Ed H...|Ich bin ein Fan von Ed Harris' Arbeit und hatte...|\n",
+ "| translate English to German: Well| Nun|\n",
+ "|translate English to German: This is a new appr...| Das ist ein neuer Ansatz für die Komödie|\n",
+ "|translate English to German: It's been mentione...|Es wurde von anderen erwähnt, die unangenehme D...|\n",
+ "|translate English to German: One of the most bo...|Einer der langwierigen Filme, die ich jemals du...|\n",
+ "|translate English to German: This movie was pla...|Dieser Film spielte im letzten Monat auf Lifeti...|\n",
+ "|translate English to German: 1983's \"Frightmare...| 1983 ist \"Frightmare\" ein merkwürdiger Film|\n",
+ "|translate English to German: 'Felony' is a B-movie| 'Felony' ist ein B-Film|\n",
+ "|translate English to German: This movie defines...| Dieser Film definiert das Wort \"verwirrt\"|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "preds.show(truncate=50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a79a6f3a-cc34-46a4-aadd-16870423fffa",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
+ "\n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
+ "\n",
+ ""
+ ]
+ },
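+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a schematic sketch (assuming the `start_triton`, `triton_fn`, and `stop_triton` helpers defined in the cells below), the lifecycle looks like:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Schematic sketch only -- the concrete definitions follow in later cells.\n",
+ "# 1. One barrier task per node launches a Triton server process:\n",
+ "#    pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "# 2. A predict_batch_udf wraps a client that sends requests to the node-local server:\n",
+ "#    generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name), ...)\n",
+ "#    preds = df.withColumn(\"preds\", generate(\"input\"))\n",
+ "# 3. A final barrier stage signals each server process to shut down:\n",
+ "#    shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
+ ]
+ },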
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1e73757e-a451-4835-98e0-257ccf7a9025",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "71b1cb49-3d8f-4eeb-937a-c0c334bd2947",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import torch\n",
+ " from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " print(f\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\")\n",
+ " tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
+ " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
+ " \n",
+ " DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ " print(f\"SERVER: Using {DEVICE} device.\")\n",
+ " model = model.to(DEVICE)\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"]).tolist()\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " inputs = tokenizer(decoded_sentences,\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\").to(DEVICE)\n",
+ " output_ids = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\n",
+ " return {\n",
+ " \"translations\": outputs,\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"ConditionalGeneration\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"translations\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1bf14846-15a3-4bc8-b0c5-ce71680d3550",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "e2c40df2-161b-483d-9d10-e462ecfb9fed",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "5bf1fafc-d9c9-4fd7-901d-da97cf4ff496",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using ports [7000, 7001, 7002]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"ConditionalGeneration\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "289b08fa-7916-44ea-8fe5-28821451db6b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 46:> (0 + 1) / 1]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2165024\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"grpc://localhost:{ports[1]}\" # or f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "e203eb19-166d-4177-aa87-fd31b7e3c90e",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
+ "\n",
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist() \n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " result_data = np.squeeze(result_data[\"translations\"], -1)\n",
+ " return result_data\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1b6b2a05-aea4-4e4d-a87d-0a6bd5ab554c",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Load and preprocess DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a5e83230-5178-4fec-bba2-0e69be40e68c",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
+ " @pandas_udf(\"string\")\n",
+ " def _preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
+ " return _preprocess(text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "aad299b0-34bb-4edb-b1e4-cd0c82bb7455",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "7934a6fc-57bc-4104-a52c-076351e77cbe",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/16 17:04:51 WARN CacheManager: Asked to cache already cached data.\n"
+ ]
+ }
+ ],
+ "source": [
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Run Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "be692f4a-cf86-4cf4-9530-7c62e479cacd",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
+ " return_type=StringType(),\n",
+ " input_tensor_shapes=[[1]],\n",
+ " batch_size=32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "0f6229ef-01c8-43c9-a259-c5df6a18d689",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 63:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 9.58 ms, sys: 1.8 ms, total: 11.4 ms\n",
+ "Wall time: 4.86 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# first pass caches model/fn\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "5a543b4c-8b29-4f61-9773-2639bbc7f728",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 66:===========================================> (6 + 2) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 5.19 ms, sys: 2.36 ms, total: 7.55 ms\n",
+ "Wall time: 4.41 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "4c0cfc4e-ef0a-435e-9fdf-72b72b6def93",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 69:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 7.08 ms, sys: 2.41 ms, total: 9.49 ms\n",
+ "Wall time: 4.27 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "results = preds.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "2d756e2e-8b60-43cb-b5f9-e27de11be24d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n",
+ "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n",
+ "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n",
+ "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n",
+ "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n",
+ "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n",
+ "| translate English to French: A good cast| Une bonne étoile|\n",
+ "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n",
+ "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n",
+ "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n",
+ "| translate English to French: Simply awful| Tout simplement terrible|\n",
+ "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n",
+ "| translate English to French: Well| Eh bien|\n",
+ "|translate English to French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n",
+ "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n",
+ "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n",
+ "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n",
+ "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n",
+ "|translate English to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n",
+ "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "preds.show(truncate=50)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "86ae68d4-57da-41d9-91b4-625ef9465d60",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Shut down servers on each executor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "16fd4601-f6d5-4ddf-9b5e-d918ab0adf3a",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[True]"
+ ]
+ },
+ "execution_count": 56,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
+ " \n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
+ "\n",
+ " return [False]\n",
+ "\n",
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "spark.stop()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "008c3e50-d321-4431-a9ab-919b35d1b042",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "application/vnd.databricks.v1+notebook": {
+ "dashboards": [],
+ "environmentMetadata": null,
+ "language": "python",
+ "notebookMetadata": {
+ "mostRecentlyExecutedCommandWithImplicitDF": {
+ "commandId": 421988607303514,
+ "dataframes": [
+ "_sqldf"
+ ]
+ },
+ "pythonIndentUnit": 4
+ },
+ "notebookName": "spark-triton-db.ipynb",
+ "widgets": {}
+ },
+ "kernelspec": {
+ "display_name": "spark-dl-torch",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb
index 94cb7df19..f96ee84df 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/conditional_generation_torch.ipynb
@@ -2,58 +2,56 @@
"cells": [
{
"cell_type": "markdown",
- "id": "777fc40d",
- "metadata": {},
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "8f6659b4-88da-4207-8d32-2674da5383a0",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"source": [
- "# PySpark Huggingface Inferencing\n",
- "## Conditional generation with PyTorch\n",
+ "\n",
+ "\n",
+ "# PySpark DL Inference\n",
+ "### Conditional generation with Huggingface\n",
"\n",
+ "In this notebook, we demonstrate distributed inference with the T5 transformer to perform sentence translation. \n",
"From: https://huggingface.co/docs/transformers/model_doc/t5"
]
},
{
"cell_type": "code",
"execution_count": 1,
- "id": "c0eed0e8",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
- ]
- }
- ],
- "source": [
- "from transformers import T5Tokenizer, T5ForConditionalGeneration"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "041ca559",
- "metadata": {},
- "source": [
- "Enabling Huggingface tokenizer parallelism so that it is not automatically disabled with Python parallelism. See [this thread](https://github.com/huggingface/transformers/issues/5486) for more info. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "6695a3e5",
"metadata": {},
"outputs": [],
"source": [
+ "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+ "\n",
+ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
+ "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
"import os\n",
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "900d6506",
+ "execution_count": 2,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
+ ]
+ }
+ ],
"source": [
"tokenizer = T5Tokenizer.from_pretrained(\"google-t5/t5-small\")\n",
"model = T5ForConditionalGeneration.from_pretrained(\"google-t5/t5-small\")\n",
@@ -71,24 +69,20 @@
},
{
"cell_type": "code",
- "execution_count": 2,
- "id": "73655aea",
+ "execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
- "input_ids = tokenizer(input_sequences, \n",
- " padding=\"longest\", \n",
- " max_length=512,\n",
- " truncation=True,\n",
- " return_tensors=\"pt\").input_ids\n",
+ "inputs = tokenizer(input_sequences,\n",
+ " padding=True, \n",
+ " return_tensors=\"pt\")\n",
"\n",
- "outputs = model.generate(input_ids, max_length=20)"
+ "outputs = model.generate(input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_length=128)"
]
},
{
"cell_type": "code",
- "execution_count": 3,
- "id": "90e54262",
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -99,7 +93,7 @@
" 'HuggingFace ist ein Unternehmen']"
]
},
- "execution_count": 3,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -109,187 +103,293 @@
]
},
{
- "cell_type": "code",
- "execution_count": 4,
- "id": "6b11c89a",
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'pt'"
- ]
+ "source": [
+ "## PySpark"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
},
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
+ "inputWidgets": {},
+ "nuid": "1b8dae4a-3bfc-4430-b28a-7350db5efed4",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
+ },
+ "outputs": [],
"source": [
- "model.framework"
+ "from pyspark.sql.types import *\n",
+ "from pyspark import SparkConf\n",
+ "from pyspark.sql import SparkSession\n",
+ "from pyspark.sql.functions import pandas_udf, col, struct\n",
+ "from pyspark.ml.functions import predict_batch_udf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a93a1424-e483-4d37-a719-32fabee3f285",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
]
},
{
"cell_type": "markdown",
- "id": "546eabe0",
"metadata": {},
"source": [
- "## PySpark"
+ "Check the cluster environment to handle any platform-specific Spark configurations."
]
},
{
"cell_type": "code",
- "execution_count": 1,
- "id": "2f6db1f0-7d68-4af7-8bd6-c9fa45906c61",
+ "execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
- "from pathlib import Path\n",
- "from datasets import load_dataset"
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
]
},
{
- "cell_type": "code",
- "execution_count": 2,
- "id": "68121304-f1df-466e-9347-c9d2b36a9b3a",
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "from pyspark.sql.types import *\n",
- "from pyspark.sql import SparkSession\n",
- "from pyspark import SparkConf"
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
]
},
{
"cell_type": "code",
- "execution_count": 3,
- "id": "6279a849",
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/10 00:10:48 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
- "24/10/10 00:10:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "25/01/06 18:28:16 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/06 18:28:16 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
- "24/10/10 00:10:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ "25/01/06 18:28:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
- "import os\n",
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
- "\n",
"conf = SparkConf()\n",
+ "\n",
"if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
"spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
"sc = spark.sparkContext"
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f08c37a5-fb0c-45f6-8630-d2af67831641",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "Load the IMBD Movie Reviews dataset from Huggingface."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 4,
- "id": "b8453111-d068-49bb-ab91-8ae3d8bcdb7a",
- "metadata": {},
+ "execution_count": 9,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f0ec30c9-365a-43c5-9c53-3497400ee548",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [],
"source": [
- "# load IMDB reviews (test) dataset\n",
- "data = load_dataset(\"imdb\", split=\"test\")"
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1e4269da-d2b3-46a5-9309-38a1ba825a47",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "#### Create PySpark DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "7ad01d4a",
- "metadata": {},
+ "execution_count": 10,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "30dab34d-8e4b-4f30-b7c2-3dff49da018b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"data": {
"text/plain": [
- "25000"
+ "StructType([StructField('text', StringType(), True)])"
]
},
- "execution_count": 5,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "lines = []\n",
- "for example in data:\n",
- " lines.append([example[\"text\"].split(\".\")[0]])\n",
- "\n",
- "len(lines)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "6fd5b472-47e8-4804-9907-772793fedb2b",
- "metadata": {},
- "source": [
- "### Create PySpark DataFrame"
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
]
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "d24d9404-0269-476e-a9dd-1842667c915a",
- "metadata": {},
+ "execution_count": 11,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "55c33cc0-5dfb-449c-ae79-80972fb04405",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"data": {
"text/plain": [
- "StructType([StructField('lines', StringType(), True)])"
+ "25000"
]
},
- "execution_count": 6,
+ "execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "df = spark.createDataFrame(lines, ['lines']).repartition(8)\n",
- "df.schema"
+ "df.count()"
]
},
{
"cell_type": "code",
- "execution_count": 7,
- "id": "4384c762-1f79-4f60-876c-94b1f552e8fb",
- "metadata": {},
+ "execution_count": 12,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "efd6d6d9-1c2c-4131-8df4-a3ef75c3fc57",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 18:28:23 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
},
{
"data": {
"text/plain": [
- "[Row(lines='(Some Spoilers) Dull as dishwater slasher flick that has this deranged homeless man Harry, Darwyn Swalve, out murdering real-estate agent all over the city of L')]"
+ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
]
},
- "execution_count": 7,
+ "execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
@@ -298,77 +398,79 @@
"df.take(1)"
]
},
- {
- "cell_type": "markdown",
- "id": "42ba3513-82dd-47e7-8193-eb4389458757",
- "metadata": {},
- "source": [
- "### Save the test dataset as parquet files"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "e7eec8ec-4126-4890-b957-025809fad67d",
- "metadata": {},
- "outputs": [],
- "source": [
- "df.write.mode(\"overwrite\").parquet(\"imdb_test\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "304e1fc8-42a3-47dd-b3c0-47efd5be1040",
- "metadata": {},
- "source": [
- "### Check arrow memory configuration"
- ]
- },
{
"cell_type": "code",
- "execution_count": 9,
- "id": "20554ea5-01be-4a30-8607-db5d87786fec",
- "metadata": {},
- "outputs": [],
+ "execution_count": 13,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "65a5b258-1634-441e-8b36-29777e54592d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 18:28:23 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
"source": [
- "if int(spark.conf.get(\"spark.sql.execution.arrow.maxRecordsPerBatch\")) > 512:\n",
- " print(\"Decreasing `spark.sql.execution.arrow.maxRecordsPerBatch` to ensure the vectorized reader won't run out of memory\")\n",
- " spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n",
- "assert len(df.head()) > 0, \"`df` should not be empty\""
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
]
},
{
"cell_type": "markdown",
- "id": "06a4ecab-c9d9-466f-ba49-902ad1fd5488",
- "metadata": {},
- "source": [
- "## Inference using Spark DL API\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "e7a00479-1347-4de8-8431-faa77f8cdf4c",
"metadata": {
- "tags": []
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "89b909f4-5732-428b-ad61-9a6c5cf94df2",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
- "outputs": [],
"source": [
- "import pandas as pd\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, pandas_udf, struct\n",
- "from pyspark.sql.types import StringType"
+ "#### Load and preprocess DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for translation."
]
},
{
"cell_type": "code",
- "execution_count": 11,
- "id": "b9a0889a-35b4-493a-8197-1146fc7efd53",
- "metadata": {},
+ "execution_count": 14,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "eb7e53d6-bbd0-48d2-a3be-36847275e2a9",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [],
"source": [
- "# only use first sentence and add prefix for conditional generation\n",
"def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
" @pandas_udf(\"string\")\n",
" def _preprocess(text: pd.Series) -> pd.Series:\n",
@@ -378,160 +480,191 @@
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "c483e4d4-9ab1-416f-a766-694e17490fd3",
- "metadata": {},
+ "execution_count": 15,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "97eee1a4-9dc4-43b0-9578-6d7f8ff338bd",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| lines|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| I am a big fan of The ABC Movies of the Week genre|\n",
- "|In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Full House\" and the long-defunc...|\n",
- "|When The Spirits Within was released, all you heard from Final Fantasy fans was how awful the movie was because it di...|\n",
- "| I like to think of myself as a bad movie connoisseur|\n",
- "|This film did well at the box office, and the producers of this mess thought the stars had such good chemistry in thi...|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second one, this incredibly dull, slow, and une...|\n",
- "| I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| I have read all of the reviews for this direct to video movie|\n",
- "|Yes, it was an awful movie, but there was a song near the beginning of the movie, I think, called \"I got a Woody\" or ...|\n",
- "| This was the most uninteresting horror flick I have seen to date|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'The French Connection\", but it...|\n",
- "|Heart of Darkness Movie Review Could a book that is well known for its eloquent wording and complicated concepts ever...|\n",
- "| A bad movie ABOUT a bad movie|\n",
- "|Apart from the fact that this film was made ( I suppose it seemed a good idea at the time considering BOTTOM was so p...|\n",
- "|Watching this movie, you just have to ask: What were they thinking? There are so many noticeably bad parts of this mo...|\n",
- "| OK, lets start with the best|\n",
- "| Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 years|\n",
- "| C|\n",
- "| Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| text|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|The only reason I'm even giving this movie a 4 is because it was made in to an episode of Mystery...|\n",
+ "|Awkward disaster mishmash has a team of scavengers coming across the overturned S.S. Poseidon, ho...|\n",
+ "|Here is a fantastic concept for a film - a series of meteors crash into a small town and the resu...|\n",
+ "|I walked out of the cinema having suffered this film after 30 mins. I left two friends pinned in ...|\n",
+ "|A wildly uneven film where the major problem is the uneasy mix of comedy and thriller. To me, the...|\n",
+ "|Leonard Rossiter and Frances de la Tour carry this film, not without a struggle, as the script wa...|\n",
+ "|A good cast... A good idea but turns out it is flawed as hypnosis is not allowed as evidence in c...|\n",
+ "|Yet again, I appear to be the only person on planet Earth who is capable of criticizing Japanese ...|\n",
+ "|As a serious horror fan, I get that certain marketing ploys are used to sell movies, especially t...|\n",
+ "|Upon writing this review I have difficulty trying to think of what to write about. Nothing much h...|\n",
+ "|Simply awful. I'm including a spoiler warning here only because of including a coupla jokes from ...|\n",
+ "|I am a fan of Ed Harris' work and I really had high expectations about this film. Having so good ...|\n",
+ "|Well...I like Patricia Kaas. She is a beautiful lady and an extremely gifted and versatile singer...|\n",
+ "|This is a new approach to comedy. It isn't funny.
The joke is that this, in and of its...|\n",
+ "|It's been mentioned by others the inane dialogue in this series and I agree.
If Mom an...|\n",
+ "|One of the most boring movies I've ever had to sit through, it's completely formulaic. Just a coo...|\n",
+ "|This movie was playing on Lifetime Movie Network last month and I decided to check it out. I watc...|\n",
+ "|1983's \"Frightmare\" is an odd little film. The director seems to be trying to combine the atmosph...|\n",
+ "|'Felony' is a B-movie. No doubt about it.
Of course, if you take a look at the cast li...|\n",
+ "|This movie defines the word \"confused\". All the actors stay true to the script. More's the pity, ...|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
- },
- {
- "data": {
- "text/plain": [
- "100"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100)\n",
- "df.show(truncate=120)\n",
- "df.count()"
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)\n",
+ "df.show(truncate=100)"
]
},
{
- "cell_type": "code",
- "execution_count": 13,
- "id": "831bc52c-a5c6-4c29-a6da-0566b5167773",
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df1 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to German: \")).select(\"input\").limit(100).cache()"
+ "Append a prefix to tell the model to translate English to French:"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "46dac59c-5a54-4576-91e0-279c8b375b95",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "100"
- ]
+ "execution_count": 16,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
},
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
+ "inputWidgets": {},
+ "nuid": "fa14304d-b409-4d07-99ef-9da7c7c76158",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
- "source": [
- "df1.count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "fef1d846-5852-4762-8527-602f32c0d7cd",
- "metadata": {},
+ },
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to German: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to German: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to German: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to German: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to German: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to German: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to German: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to German: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to German: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to German: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to German: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to German: OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to German: C|\n",
- "| Translate English to German: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|translate English to French: The only reason I'm even giving this movie a 4 is because it was mad...|\n",
+ "|translate English to French: Awkward disaster mishmash has a team of scavengers coming across the...|\n",
+ "|translate English to French: Here is a fantastic concept for a film - a series of meteors crash i...|\n",
+ "| translate English to French: I walked out of the cinema having suffered this film after 30 mins|\n",
+ "|translate English to French: A wildly uneven film where the major problem is the uneasy mix of co...|\n",
+ "|translate English to French: Leonard Rossiter and Frances de la Tour carry this film, not without...|\n",
+ "| translate English to French: A good cast|\n",
+ "|translate English to French: Yet again, I appear to be the only person on planet Earth who is cap...|\n",
+ "|translate English to French: As a serious horror fan, I get that certain marketing ploys are used...|\n",
+ "|translate English to French: Upon writing this review I have difficulty trying to think of what t...|\n",
+ "| translate English to French: Simply awful|\n",
+ "|translate English to French: I am a fan of Ed Harris' work and I really had high expectations abo...|\n",
+ "| translate English to French: Well|\n",
+ "| translate English to French: This is a new approach to comedy|\n",
+ "|translate English to French: It's been mentioned by others the inane dialogue in this series and ...|\n",
+ "|translate English to French: One of the most boring movies I've ever had to sit through, it's com...|\n",
+ "|translate English to French: This movie was playing on Lifetime Movie Network last month and I de...|\n",
+ "| translate English to French: 1983's \"Frightmare\" is an odd little film|\n",
+ "| translate English to French: 'Felony' is a B-movie|\n",
+ "| translate English to French: This movie defines the word \"confused\"|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "df1.show(truncate=120)"
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()\n",
+ "input_df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "bc9cbdd2-1ca6-48e4-a549-792b3726525b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "e7ae69d3-70c2-4765-928f-c96a7ba59829",
- "metadata": {},
+ "execution_count": 17,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "adb81177-442d-42ab-b86d-d8792201b4c8",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [],
"source": [
"def predict_batch_fn():\n",
" import numpy as np\n",
+ " import torch\n",
" from transformers import T5ForConditionalGeneration, T5Tokenizer\n",
+ " from pyspark import TaskContext\n",
"\n",
- " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " print(f\"Initializing model on worker {TaskContext.get().partitionId()}, device {device}.\")\n",
+ " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\").to(device)\n",
" tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
"\n",
" def predict(inputs):\n",
- " flattened = np.squeeze(inputs).tolist() # convert 2d numpy array of string into flattened python list\n",
- " input_ids = tokenizer(flattened, \n",
- " padding=\"longest\", \n",
- " max_length=128,\n",
- " truncation=True,\n",
- " return_tensors=\"pt\").input_ids\n",
- " output_ids = model.generate(input_ids, max_length=20)\n",
- " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " inputs = tokenizer(flattened, \n",
+ " padding=True,\n",
+ " return_tensors=\"pt\").to(device)\n",
+ " outputs = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " string_outputs = np.array([tokenizer.decode(o, skip_special_tokens=True) for o in outputs])\n",
" print(\"predict: {}\".format(len(flattened)))\n",
- " \n",
" return string_outputs\n",
" \n",
" return predict"
@@ -539,35 +672,57 @@
},
{
"cell_type": "code",
- "execution_count": 17,
- "id": "36684f59-d947-43f8-a2e8-c7a423764e88",
- "metadata": {},
+ "execution_count": 18,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "20aab3a1-2284-4c07-9ce1-a20cf54d88f3",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [],
"source": [
"generate = predict_batch_udf(predict_batch_fn,\n",
" return_type=StringType(),\n",
- " batch_size=10)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 18,
- "id": "6a01c855-8fa1-4765-a3a5-2c9dd872df10",
- "metadata": {},
+ "execution_count": 19,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a8d6f48e-09e7-4fc7-9d2f-1b68bc2976a7",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 21:> (0 + 1) / 1]\r"
+ "[Stage 24:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 6.58 ms, sys: 4.68 ms, total: 11.3 ms\n",
- "Wall time: 7.41 s\n"
+ "CPU times: user 7.33 ms, sys: 1.47 ms, total: 8.8 ms\n",
+ "Wall time: 6.8 s\n"
]
},
{
@@ -581,29 +736,40 @@
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "preds = df1.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 19,
- "id": "d912d4b0-cd0b-44ea-859a-b23455cc2700",
- "metadata": {},
+ "execution_count": 20,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "abe2271d-0077-48f6-98b1-93524dd86447",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 23:> (0 + 1) / 1]\r"
+ "[Stage 27:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.87 ms, sys: 1.8 ms, total: 3.67 ms\n",
- "Wall time: 5.71 s\n"
+ "CPU times: user 7 ms, sys: 1.36 ms, total: 8.36 ms\n",
+ "Wall time: 3.91 s\n"
]
},
{
@@ -616,29 +782,40 @@
],
"source": [
"%%time\n",
- "preds = df1.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 20,
- "id": "5fe3d88b-30f7-468f-8db8-1f4118d0f26c",
- "metadata": {},
+ "execution_count": 21,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "77623711-a742-4262-8839-16fc3ddd1af7",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 25:> (0 + 1) / 1]\r"
+ "[Stage 30:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.99 ms, sys: 1.42 ms, total: 4.42 ms\n",
- "Wall time: 5.69 s\n"
+ "CPU times: user 4.63 ms, sys: 2.59 ms, total: 7.22 ms\n",
+ "Wall time: 3.9 s\n"
]
},
{
@@ -651,140 +828,98 @@
],
"source": [
"%%time\n",
- "preds = df1.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 21,
- "id": "4ad9b365-4b9a-438e-8fdf-47da55cb1cf4",
- "metadata": {},
+ "execution_count": 22,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "f339c654-52fd-4992-b054-188dfb260e5d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 27:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to German: This is so overly clichéd yo...| Das ist so übertrieben klischeehaft, dass Sie es nach den|\n",
- "|Translate English to German: I am a big fan of The ABC Mo...| Ich bin ein großer Fan von The ABC Movies of the Week|\n",
- "|Translate English to German: In the early 1990's \"Step-by...| Anfang der 1990er Jahre kam \"Step-by-Step\" als müh|\n",
- "|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|\n",
- "|Translate English to German: I like to think of myself as...| Ich halte mich gerne als schlechter Filmliebhaber|\n",
- "|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|\n",
- "|Translate English to German: Following the pleasingly atm...|Nach dem erfreulich stimmungsvollen Original und dem amüsant|\n",
- "|Translate English to German: I like CKY and Viva La Bam, ...| Ich mag CKY und Viva La Bam, also konnte ich mich nicht|\n",
- "|Translate English to German: I have read all of the revie...| Ich habe alle Rezensionen zu diesem direkten Film gelesen.|\n",
- "|Translate English to German: Yes, it was an awful movie, ...| Ja, es war ein schrecklicher Film, aber es gab|\n",
- "|Translate English to German: This was the most uninterest...|Dies war der größte Horrorfilm, den ich bisher gesehen habe.|\n",
- "|Translate English to German: I don't know if this excepti...|Ich weiß nicht, ob dieser außergewöhnlich langweilige Fil...|\n",
- "|Translate English to German: Heart of Darkness Movie Revi...|Herz der Dunkelheit Film Review Kann ein Buch, das für se...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie| Ein schlechter Film ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that thi...| Dieser Film wurde zwar fertiggestellt, aber es schien mir|\n",
- "|Translate English to German: Watching this movie, you jus...|Wenn man diesen Film anschaut, muss man einfach fragen: W...|\n",
- "| Translate English to German: OK, lets start with the best| OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) ...| Anna Christie (Greta Garbo) kehrt nach 15 Jahren zurück,|\n",
- "| Translate English to German: C| C|\n",
- "|Translate English to German: Tom and Jerry are transporti...|Tom und Jerry transportieren Güter über Flugzeug nach Afrika|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n",
+ "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n",
+ "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n",
+ "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n",
+ "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n",
+ "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n",
+ "| translate English to French: A good cast| Une bonne étoile|\n",
+ "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n",
+ "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n",
+ "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n",
+ "| translate English to French: Simply awful| Tout simplement terrible|\n",
+ "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n",
+ "| translate English to French: Well| Eh bien|\n",
+ "|translate English to French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n",
+ "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n",
+ "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n",
+ "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n",
+ "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n",
+ "|translate English to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n",
+ "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
}
],
"source": [
- "preds.show(truncate=60)"
+ "preds.show(truncate=50)"
]
},
{
- "cell_type": "code",
- "execution_count": 22,
- "id": "1eb0c83b-d91b-4f8c-a5e7-c35f55c88108",
+ "cell_type": "markdown",
"metadata": {},
- "outputs": [],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df2 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to French: \")).select(\"input\").limit(100).cache()"
+ "Let's try English to German:"
]
},
{
"cell_type": "code",
"execution_count": 23,
- "id": "054f94fd-fe79-41e7-b1c7-6124083acc72",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to French: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to French: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to French: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to French: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to French: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to French: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to French: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to French: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to French: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to French: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to French: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to French: OK, lets start with the best|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to French: C|\n",
- "| Translate English to French: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "df2.show(truncate=120)"
+ "input_df2 = df.select(preprocess(col(\"text\"), \"translate English to German: \").alias(\"input\")).cache()"
]
},
{
"cell_type": "code",
"execution_count": 24,
- "id": "6f6b70f9-188a-402b-9143-78a5788140e4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 33:> (0 + 1) / 1]\r"
+ "[Stage 36:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.46 ms, sys: 2.2 ms, total: 4.67 ms\n",
- "Wall time: 7.38 s\n"
+ "CPU times: user 4.69 ms, sys: 1.9 ms, total: 6.59 ms\n",
+ "Wall time: 4.01 s\n"
]
},
{
@@ -798,29 +933,28 @@
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "preds = df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
"execution_count": 25,
- "id": "031a6a5e-7999-4653-b394-19ed478d8c96",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 35:> (0 + 1) / 1]\r"
+ "[Stage 39:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.34 ms, sys: 1.13 ms, total: 4.47 ms\n",
- "Wall time: 6.1 s\n"
+ "CPU times: user 5.58 ms, sys: 549 μs, total: 6.13 ms\n",
+ "Wall time: 3.69 s\n"
]
},
{
@@ -833,29 +967,28 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(\"input\"))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
"execution_count": 26,
- "id": "229b6515-82f6-4e9c-90f0-a9c3cfb26301",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 37:> (0 + 1) / 1]\r"
+ "[Stage 42:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.72 ms, sys: 2.89 ms, total: 4.6 ms\n",
- "Wall time: 5.93 s\n"
+ "CPU times: user 4.77 ms, sys: 2.29 ms, total: 7.07 ms\n",
+ "Wall time: 3.69 s\n"
]
},
{
@@ -868,472 +1001,406 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df2.withColumn(\"preds\", generate(col(\"input\")))\n",
"result = preds.collect()"
]
},
{
"cell_type": "code",
"execution_count": 27,
- "id": "8be750ac-fa39-452e-bb4c-c2270bc2f70d",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 39:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to French: This is so overly clichéd yo...| Vous ne pouvez pas en tirer d'un tel cliché|\n",
- "|Translate English to French: I am a big fan of The ABC Mo...| Je suis un grand fan du genre The ABC Movies of the Week|\n",
- "|Translate English to French: In the early 1990's \"Step-by...| Au début des années 1990, «Step-by-Step» a été une|\n",
- "|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|\n",
- "|Translate English to French: I like to think of myself as...| Je me considère comme un mauvais réalisateur de films|\n",
- "|Translate English to French: This film did well at the bo...| Ce film a bien avancé à la salle de cinéma et|\n",
- "|Translate English to French: Following the pleasingly atm...|Après l'original agréablement atmosphérique et la seconde...|\n",
- "|Translate English to French: I like CKY and Viva La Bam, ...| Je m'aime CKY et Viva La Bam,|\n",
- "|Translate English to French: I have read all of the revie...| J'ai lu tous les commentaires pour ce film direct à vidéo|\n",
- "|Translate English to French: Yes, it was an awful movie, ...| Oui, c'était un film terrible, mais il y avait une chanson|\n",
- "|Translate English to French: This was the most uninterest...| Ce fut le plus inquiétant et le plus inquiétant d'h|\n",
- "|Translate English to French: I don't know if this excepti...|Je ne sais pas si ce film extrêmement tacheté était desti...|\n",
- "|Translate English to French: Heart of Darkness Movie Revi...| Un livre connu pour son éloquence et ses concepts complexes|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie| Un mauvais film ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that thi...|En plus du fait que ce film a été réalisé (je suppose quil s|\n",
- "|Translate English to French: Watching this movie, you jus...|Vous devez simplement vous demander : « Que pense-t-il? » Il|\n",
- "| Translate English to French: OK, lets start with the best| OK, s'il y a lieu de commencer par le meilleur|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) ...|Anna Christie (Greta Garbo) retourne pour voir son père C...|\n",
- "| Translate English to French: C| C|\n",
- "|Translate English to French: Tom and Jerry are transporti...| Tom et Jerry transportent des marchandises par avion en|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to German: The only reason I'...|Der einzige Grund, warum ich sogar diesen Film ...|\n",
+ "|translate English to German: Awkward disaster m...|Awkward-Katastrophenmischmash hat ein Team von ...|\n",
+ "|translate English to German: Here is a fantasti...|Hier ist ein fantastisches Konzept für einen Fi...|\n",
+ "|translate English to German: I walked out of th...|Ich ging aus dem Kino, nachdem ich diesen Film ...|\n",
+ "|translate English to German: A wildly uneven fi...|Ein völlig ungleicher Film, in dem das Hauptpro...|\n",
+ "|translate English to German: Leonard Rossiter a...|Leonard Rossiter und Frances de la Tour tragen ...|\n",
+ "| translate English to German: A good cast| Gutes Casting|\n",
+ "|translate English to German: Yet again, I appea...|Ich scheine wieder einmal die einzige Person au...|\n",
+ "|translate English to German: As a serious horro...|Als ernsthafter Horrorfan erhalte ich, dass bes...|\n",
+ "|translate English to German: Upon writing this ...|Ich habe Schwierigkeiten, mich an die Regeln zu...|\n",
+ "| translate English to German: Simply awful| Einfach schrecklich|\n",
+ "|translate English to German: I am a fan of Ed H...|Ich bin ein Fan von Ed Harris' Arbeit und hatte...|\n",
+ "| translate English to German: Well| Nun|\n",
+ "|translate English to German: This is a new appr...| Das ist ein neuer Ansatz für die Komödie|\n",
+ "|translate English to German: It's been mentione...|Es wurde von anderen erwähnt, die unangenehme D...|\n",
+ "|translate English to German: One of the most bo...|Einer der langwierigen Filme, die ich jemals du...|\n",
+ "|translate English to German: This movie was pla...|Dieser Film spielte im letzten Monat auf Lifeti...|\n",
+ "|translate English to German: 1983's \"Frightmare...| 1983 ist \"Frightmare\" ein merkwürdiger Film|\n",
+ "|translate English to German: 'Felony' is a B-movie| 'Felony' ist ein B-Film|\n",
+ "|translate English to German: This movie defines...| Dieser Film definiert das Wort \"verwirrt\"|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
}
],
"source": [
- "preds.show(truncate=60)"
+ "preds.show(truncate=50)"
]
},
{
"cell_type": "markdown",
- "id": "bcabb2a8-3880-46ec-8e01-5a10f71fe83d",
- "metadata": {},
- "source": [
- "### Using Triton Inference Server\n",
- "\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
- ]
- },
- {
- "cell_type": "markdown",
- "id": "5d98fa52-7665-49bf-865a-feec86effe23",
- "metadata": {},
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a79a6f3a-cc34-46a4-aadd-16870423fffa",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n huggingface-torch -c conda-forge python=3.10.0\n",
- "conda activate huggingface-torch\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
"\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers\n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
"\n",
- "conda-pack # huggingface-torch.tar.gz\n",
- "```"
+ ""
]
},
{
"cell_type": "code",
"execution_count": 28,
- "id": "b858cf85-82e6-41ef-905b-d8c5d6fea492",
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1e73757e-a451-4835-98e0-257ccf7a9025",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [],
"source": [
- "import os"
+ "from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": 29,
- "id": "05ce7c77-d562-45e8-89bb-cd656aba5a5f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/hf_generation_torch models\n",
- "\n",
- "# add custom execution environment\n",
- "cp huggingface-torch.tar.gz models"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "a552865c-5dad-4f25-8834-f41e253ac2f6",
"metadata": {
- "tags": []
- },
- "source": [
- "#### Start Triton Server on each executor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "id": "afd00b7e-8150-4c95-a2e4-037e9c90f92a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
},
- "execution_count": 30,
- "metadata": {},
- "output_type": "execute_result"
+ "inputWidgets": {},
+ "nuid": "71b1cb49-3d8f-4eeb-937a-c0c334bd2947",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
+ },
+ "outputs": [],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "huggingface_cache_dir = \"{}/.cache/huggingface\".format(os.path.expanduser('~'))\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
+ "def triton_server(ports):\n",
" import time\n",
- " import tritonclient.grpc as grpcclient\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import torch\n",
+ " from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " print(f\"SERVER: Initializing Conditional Generation model on worker {TaskContext.get().partitionId()}.\")\n",
+ " tokenizer = T5Tokenizer.from_pretrained(\"t5-small\")\n",
+ " model = T5ForConditionalGeneration.from_pretrained(\"t5-small\")\n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " environment=[\n",
- " \"TRANSFORMERS_CACHE=/cache\"\n",
+ " DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
+ " print(f\"SERVER: Using {DEVICE} device.\")\n",
+ " model = model.to(DEVICE)\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"]).tolist()\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " inputs = tokenizer(decoded_sentences,\n",
+ " padding=True,\n",
+ " return_tensors=\"pt\").to(DEVICE)\n",
+ " output_ids = model.generate(input_ids=inputs[\"input_ids\"],\n",
+ " attention_mask=inputs[\"attention_mask\"],\n",
+ " max_length=128)\n",
+ " outputs = np.array([[tokenizer.decode(o, skip_special_tokens=True)] for o in output_ids])\n",
+ " return {\n",
+ " \"translations\": outputs,\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"ConditionalGeneration\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"translations\", dtype=object, shape=(-1,)),\n",
" ],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"1G\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " huggingface_cache_dir: {\"bind\": \"/cache\", \"mode\": \"rw\"}\n",
- " }\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
" )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
"\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
"\n",
- " return [True]\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
"\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
]
},
{
"cell_type": "markdown",
- "id": "528d2df6-49fc-4be7-a534-a087dfe31c84",
- "metadata": {},
- "source": [
- "#### Run inference"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "id": "1a997c33-5202-466d-8304-b8c30f32978f",
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1bf14846-15a3-4bc8-b0c5-ce71680d3550",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
- "outputs": [],
"source": [
- "import pandas as pd\n",
- "from functools import partial\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, pandas_udf, struct\n",
- "from pyspark.sql.types import StringType"
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
]
},
{
"cell_type": "code",
- "execution_count": 32,
- "id": "9dea1875-6b95-4fc0-926d-a625a441b33d",
+ "execution_count": 30,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "e2c40df2-161b-483d-9d10-e462ecfb9fed",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100).cache()"
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
]
},
{
- "cell_type": "code",
- "execution_count": 33,
- "id": "5d6c54e7-534d-406f-b8e6-fd592efd0ab2",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "# only use first sentence and add prefix for conditional generation\n",
- "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
- " @pandas_udf(\"string\")\n",
- " def _preprocess(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
- " return _preprocess(text)"
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
- "execution_count": 34,
- "id": "dc1bbbe3-4232-49e5-80f6-99976524b73b",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 31,
+ "metadata": {},
"outputs": [],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df1 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to German: \")).select(\"input\").limit(100)"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
]
},
{
"cell_type": "code",
- "execution_count": 35,
- "id": "5d10c61c-6102-4d19-8dd6-0c7b5b65343e",
+ "execution_count": 32,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "5bf1fafc-d9c9-4fd7-901d-da97cf4ff496",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to German: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to German: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to German: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to German: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to German: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to German: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to German: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to German: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to German: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to German: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to German: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to German: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to German: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to German: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to German: OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to German: C|\n",
- "| Translate English to German: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
]
}
],
"source": [
- "df1.show(truncate=120)"
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
- "cell_type": "code",
- "execution_count": 36,
- "id": "2e0907da-a5d9-4c3b-9db4-ce5e70ca9bb4",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "metadata": {},
"source": [
- "def triton_fn(triton_uri, model_name):\n",
- " import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool8),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 37,
- "id": "9308bdd7-6f67-484d-8b51-dd1e1b2960ba",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 33,
+ "metadata": {},
"outputs": [],
"source": [
- "generate = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"hf_generation_torch\"),\n",
- " return_type=StringType(),\n",
- " input_tensor_shapes=[[1]],\n",
- " batch_size=100)"
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
]
},
{
"cell_type": "code",
- "execution_count": 38,
- "id": "38484ffd-370d-492b-8ca4-9eff9f242a9f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 34,
+ "metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 45:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.61 ms, sys: 1.26 ms, total: 5.87 ms\n",
- "Wall time: 2.04 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "Using ports [7000, 7001, 7002]\n"
]
}
],
"source": [
- "%%time\n",
- "# first pass caches model/fn\n",
- "preds = df1.withColumn(\"preds\", generate(struct(\"input\")))\n",
- "results = preds.collect()"
+ "model_name = \"ConditionalGeneration\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
]
},
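+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next cell launches one Triton server per node via the `start_triton` helper defined earlier in this notebook. A minimal sketch of the shape of such a helper (illustrative only; `_serve` is a hypothetical stand-in for the real model-binding logic):\n",
+ "\n",
+ "```python\n",
+ "def start_triton(ports, model_name):\n",
+ "    import socket\n",
+ "    from multiprocessing import Process\n",
+ "\n",
+ "    def _serve():\n",
+ "        from pytriton.triton import Triton, TritonConfig\n",
+ "        # Bind the HTTP/gRPC/metrics ports found by find_ports()\n",
+ "        config = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ "        with Triton(config=config) as triton:\n",
+ "            # triton.bind(model_name=model_name, infer_func=..., inputs=..., outputs=...)\n",
+ "            triton.serve()  # blocks inside the child process\n",
+ "\n",
+ "    p = Process(target=_serve)\n",
+ "    p.start()\n",
+ "    return [(socket.gethostname(), p.pid)]  # collectAsMap() -> {hostname: pid}\n",
+ "```"
+ ]
+ },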
{
"cell_type": "code",
- "execution_count": 39,
- "id": "ebcb6699-3ac2-4529-ab0f-fab0a5e792da",
+ "execution_count": 35,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "289b08fa-7916-44ea-8fe5-28821451db6b",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 47:> (0 + 1) / 1]\r"
+ "[Stage 46:> (0 + 1) / 1]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.16 ms, sys: 641 μs, total: 3.8 ms\n",
- "Wall time: 1.58 s\n"
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2569775\n",
+ "}\n"
]
},
{
@@ -1345,190 +1412,219 @@
}
],
"source": [
- "%%time\n",
- "preds = df1.withColumn(\"preds\", generate(\"input\"))\n",
- "results = preds.collect()"
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
]
},
{
"cell_type": "code",
- "execution_count": 40,
- "id": "e2ed18ad-d00b-472c-b2c3-047932f2105d",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "e203eb19-166d-4177-aa87-fd31b7e3c90e",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 49:> (0 + 1) / 1]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 1.91 ms, sys: 2.38 ms, total: 4.29 ms\n",
- "Wall time: 1.75 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
+ "\n",
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist() \n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " result_data = np.squeeze(result_data[\"translations\"], -1)\n",
+ " return result_data\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
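+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`predict_batch_udf` will call the returned `infer_batch` closure with a numpy array of shape `(batch, 1)`. You can sanity-check the client locally before wiring it into the UDF (illustrative only; assumes the server started above is reachable):\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "\n",
+ "fn = triton_fn(url, model_name)\n",
+ "sample = np.array([[\"translate English to French: Hello, world.\"]], dtype=object)\n",
+ "print(fn(sample))  # expect a length-1 array containing the French translation\n",
+ "```"
+ ]
+ },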
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "1b6b2a05-aea4-4e4d-a87d-0a6bd5ab554c",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
+ },
"source": [
- "%%time\n",
- "preds = df1.withColumn(\"preds\", generate(col(\"input\")))\n",
- "results = preds.collect()"
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 41,
- "id": "0cd64a1c-beb8-47d5-ac6f-e8525bb61176",
+ "execution_count": 38,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "a5e83230-5178-4fec-bba2-0e69be40e68c",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to German: This is so overly clichéd yo...| Das ist so übertrieben klischeehaft, dass Sie es nach den|\n",
- "|Translate English to German: I am a big fan of The ABC Mo...| Ich bin ein großer Fan von The ABC Movies of the Week|\n",
- "|Translate English to German: In the early 1990's \"Step-by...| Anfang der 1990er Jahre kam \"Step-by-Step\" als müh|\n",
- "|Translate English to German: When The Spirits Within was ...|Als The Spirits Within veröffentlicht wurde, hörten Sie v...|\n",
- "|Translate English to German: I like to think of myself as...| Ich halte mich gerne als schlechter Filmliebhaber|\n",
- "|Translate English to German: This film did well at the bo...|Dieser Film hat sich gut an der Boxoffice ereignet, und d...|\n",
- "|Translate English to German: Following the pleasingly atm...|Nach dem erfreulich stimmungsvollen Original und dem amüsant|\n",
- "|Translate English to German: I like CKY and Viva La Bam, ...| Ich mag CKY und Viva La Bam, also konnte ich mich nicht|\n",
- "|Translate English to German: I have read all of the revie...| Ich habe alle Rezensionen zu diesem direkten Film gelesen.|\n",
- "|Translate English to German: Yes, it was an awful movie, ...| Ja, es war ein schrecklicher Film, aber es gab|\n",
- "|Translate English to German: This was the most uninterest...|Dies war der größte Horrorfilm, den ich bisher gesehen habe.|\n",
- "|Translate English to German: I don't know if this excepti...|Ich weiß nicht, ob dieser außergewöhnlich langweilige Fil...|\n",
- "|Translate English to German: Heart of Darkness Movie Revi...|Herz der Dunkelheit Film Review Kann ein Buch, das für se...|\n",
- "| Translate English to German: A bad movie ABOUT a bad movie| Ein schlechter Film ABOUT a bad movie|\n",
- "|Translate English to German: Apart from the fact that thi...| Dieser Film wurde zwar fertiggestellt, aber es schien mir|\n",
- "|Translate English to German: Watching this movie, you jus...|Wenn man diesen Film anschaut, muss man einfach fragen: W...|\n",
- "| Translate English to German: OK, lets start with the best| OK, lets start with the best|\n",
- "|Translate English to German: Anna Christie (Greta Garbo) ...| Anna Christie (Greta Garbo) kehrt nach 15 Jahren zurück,|\n",
- "| Translate English to German: C| C|\n",
- "|Translate English to German: Tom and Jerry are transporti...|Tom und Jerry transportieren Güter über Flugzeug nach Afrika|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
+ "outputs": [],
+ "source": [
+ "def preprocess(text: pd.Series, prefix: str = \"\") -> pd.Series:\n",
+ " @pandas_udf(\"string\")\n",
+ " def _preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([prefix + s.split(\".\")[0] for s in text])\n",
+ " return _preprocess(text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "aad299b0-34bb-4edb-b1e4-cd0c82bb7455",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
+ },
+ "outputs": [],
"source": [
- "preds.show(truncate=60)"
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)"
]
},
{
"cell_type": "code",
- "execution_count": 42,
- "id": "af70fed8-0f2b-4ea7-841c-476afdf9b1c0",
+ "execution_count": 40,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "7934a6fc-57bc-4104-a52c-076351e77cbe",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/10 00:12:21 WARN CacheManager: Asked to cache already cached data.\n"
+ "25/01/06 18:28:55 WARN CacheManager: Asked to cache already cached data.\n"
]
}
],
"source": [
- "# only use first 100 rows, since generation takes a while\n",
- "df2 = df.withColumn(\"input\", preprocess(col(\"lines\"), \"Translate English to French: \")).select(\"input\").limit(100).cache()"
+ "input_df = df.select(preprocess(col(\"text\"), \"translate English to French: \").alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "#### Run Inference"
]
},
{
"cell_type": "code",
- "execution_count": 43,
- "id": "ef075e10-e22c-4236-9e0b-cb47cf2d3d06",
+ "execution_count": 41,
"metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| input|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| Translate English to French: This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| Translate English to French: I am a big fan of The ABC Movies of the Week genre|\n",
- "|Translate English to French: In the early 1990's \"Step-by-Step\" came as a tedious combination of the ultra-cheesy \"Fu...|\n",
- "|Translate English to French: When The Spirits Within was released, all you heard from Final Fantasy fans was how awfu...|\n",
- "| Translate English to French: I like to think of myself as a bad movie connoisseur|\n",
- "|Translate English to French: This film did well at the box office, and the producers of this mess thought the stars h...|\n",
- "|Translate English to French: Following the pleasingly atmospheric original and the amusingly silly second one, this i...|\n",
- "| Translate English to French: I like CKY and Viva La Bam, so I couldn't resist this when I saw it for £1|\n",
- "| Translate English to French: I have read all of the reviews for this direct to video movie|\n",
- "|Translate English to French: Yes, it was an awful movie, but there was a song near the beginning of the movie, I thin...|\n",
- "| Translate English to French: This was the most uninteresting horror flick I have seen to date|\n",
- "|Translate English to French: I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'T...|\n",
- "|Translate English to French: Heart of Darkness Movie Review Could a book that is well known for its eloquent wording ...|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that this film was made ( I suppose it seemed a good idea at the tim...|\n",
- "|Translate English to French: Watching this movie, you just have to ask: What were they thinking? There are so many no...|\n",
- "| Translate English to French: OK, lets start with the best|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) returns to see her father Chris (George F Marion) after 15 y...|\n",
- "| Translate English to French: C|\n",
- "| Translate English to French: Tom and Jerry are transporting goods via airplane to Africa|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "be692f4a-cf86-4cf4-9530-7c62e479cacd",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
}
- ],
+ },
+ "outputs": [],
"source": [
- "df2.show(truncate=120)"
+ "generate = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
+ " return_type=StringType(),\n",
+ " input_tensor_shapes=[[1]],\n",
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 44,
- "id": "2e7e4af8-b815-4375-b851-8368309ee8e1",
+ "execution_count": 42,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "0f6229ef-01c8-43c9-a259-c5df6a18d689",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 55:> (0 + 1) / 1]\r"
+ "[Stage 50:====================================> (5 + 3) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.4 ms, sys: 2.75 ms, total: 6.14 ms\n",
- "Wall time: 1.96 s\n"
+ "CPU times: user 8.72 ms, sys: 1.56 ms, total: 10.3 ms\n",
+ "Wall time: 6.27 s\n"
]
},
{
@@ -1541,33 +1637,41 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(struct(\"input\")))\n",
+ "# first pass caches model/fn\n",
+ "preds = input_df.withColumn(\"preds\", generate(struct(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 45,
- "id": "7b0aefb0-a96b-4791-a23c-1ce9b24eb20c",
+ "execution_count": 43,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "5a543b4c-8b29-4f61-9773-2639bbc7f728",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 57:> (0 + 1) / 1]\r"
+ "[Stage 53:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.76 ms, sys: 897 μs, total: 4.66 ms\n",
- "Wall time: 1.61 s\n"
+ "CPU times: user 7.54 ms, sys: 0 ns, total: 7.54 ms\n",
+ "Wall time: 4.27 s\n"
]
},
{
@@ -1580,33 +1684,40 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(\"input\"))\n",
+ "preds = input_df.withColumn(\"preds\", generate(\"input\"))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 46,
- "id": "1214b75b-a373-4579-b4c6-0cb8627da776",
+ "execution_count": 44,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "4c0cfc4e-ef0a-435e-9fdf-72b72b6def93",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 59:> (0 + 1) / 1]\r"
+ "[Stage 56:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.61 ms, sys: 2.26 ms, total: 4.87 ms\n",
- "Wall time: 1.67 s\n"
+ "CPU times: user 5.12 ms, sys: 1.02 ms, total: 6.13 ms\n",
+ "Wall time: 4.2 s\n"
]
},
{
@@ -1619,77 +1730,107 @@
],
"source": [
"%%time\n",
- "preds = df2.withColumn(\"preds\", generate(col(\"input\")))\n",
+ "preds = input_df.withColumn(\"preds\", generate(col(\"input\")))\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 47,
- "id": "c9dbd21f-9e37-4221-b765-80ba8c80b884",
+ "execution_count": 45,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "2d756e2e-8b60-43cb-b5f9-e27de11be24d",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| input| preds|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|Translate English to French: This is so overly clichéd yo...| Vous ne pouvez pas en tirer d'un tel cliché|\n",
- "|Translate English to French: I am a big fan of The ABC Mo...| Je suis un grand fan du genre The ABC Movies of the Week|\n",
- "|Translate English to French: In the early 1990's \"Step-by...| Au début des années 1990, «Step-by-Step» a été une|\n",
- "|Translate English to French: When The Spirits Within was ...|Lorsque The Spirits Within a été publié, tout ce que vous...|\n",
- "|Translate English to French: I like to think of myself as...| Je me considère comme un mauvais réalisateur de films|\n",
- "|Translate English to French: This film did well at the bo...| Ce film a bien avancé à la salle de cinéma et|\n",
- "|Translate English to French: Following the pleasingly atm...|Après l'original agréablement atmosphérique et la seconde...|\n",
- "|Translate English to French: I like CKY and Viva La Bam, ...| Je m'aime CKY et Viva La Bam,|\n",
- "|Translate English to French: I have read all of the revie...| J'ai lu tous les commentaires pour ce film direct à vidéo|\n",
- "|Translate English to French: Yes, it was an awful movie, ...| Oui, c'était un film terrible, mais il y avait une chanson|\n",
- "|Translate English to French: This was the most uninterest...| Ce fut le plus inquiétant et le plus inquiétant d'h|\n",
- "|Translate English to French: I don't know if this excepti...|Je ne sais pas si ce film extrêmement tacheté était desti...|\n",
- "|Translate English to French: Heart of Darkness Movie Revi...| Un livre connu pour son éloquence et ses concepts complexes|\n",
- "| Translate English to French: A bad movie ABOUT a bad movie| Un mauvais film ABOUT a bad movie|\n",
- "|Translate English to French: Apart from the fact that thi...|En plus du fait que ce film a été réalisé (je suppose quil s|\n",
- "|Translate English to French: Watching this movie, you jus...|Vous devez simplement vous demander : « Que pense-t-il? » Il|\n",
- "| Translate English to French: OK, lets start with the best| OK, s'il y a lieu de commencer par le meilleur|\n",
- "|Translate English to French: Anna Christie (Greta Garbo) ...|Anna Christie (Greta Garbo) retourne pour voir son père C...|\n",
- "| Translate English to French: C| C|\n",
- "|Translate English to French: Tom and Jerry are transporti...| Tom et Jerry transportent des marchandises par avion en|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| preds|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|translate English to French: The only reason I'...|La seule raison pour laquelle je donne même ce ...|\n",
+ "|translate English to French: Awkward disaster m...|La mishmash d’Awkward a eu une équipe de scaven...|\n",
+ "|translate English to French: Here is a fantasti...|Voici un concept fantastique pour un film : une...|\n",
+ "|translate English to French: I walked out of th...|Je me suis rendu du cinéma après avoir subi ce ...|\n",
+ "|translate English to French: A wildly uneven fi...|Un film extrêmement inégal où le problème majeu...|\n",
+ "|translate English to French: Leonard Rossiter a...|Leonard Rossiter et Frances de la Tour mettent ...|\n",
+ "| translate English to French: A good cast| Une bonne étoile|\n",
+ "|translate English to French: Yet again, I appea...|Encore une fois, je semble être la seule person...|\n",
+ "|translate English to French: As a serious horro...|En tant que grand fan d'horreur, je peux obteni...|\n",
+ "|translate English to French: Upon writing this ...|la suite de cette étude, j'ai de la difficulté ...|\n",
+ "| translate English to French: Simply awful| Tout simplement terrible|\n",
+ "|translate English to French: I am a fan of Ed H...|Je suis un fan de l'oeuvre d'Ed Harris et j'ai ...|\n",
+ "| translate English to French: Well| Eh bien|\n",
+ "|translate English to French: This is a new appr...| Il s’agit d’une nouvelle approche de la comédie.|\n",
+ "|translate English to French: It's been mentione...|Il a été mentionné par d'autres le dialogue ina...|\n",
+ "|translate English to French: One of the most bo...|Un des films les plus ennuyeux que je n'ai jama...|\n",
+ "|translate English to French: This movie was pla...|Ce film jouait sur Lifetime Movie Network le mo...|\n",
+ "|translate English to French: 1983's \"Frightmare...|Le film \"Frightmare\" de 1983 est un petit film ...|\n",
+ "|translate English to French: 'Felony' is a B-movie| 'Felony' est un mouvement B|\n",
+ "|translate English to French: This movie defines...| Ce film définit le mot «confus»|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
- "preds.show(truncate=60)"
+ "preds.show(truncate=50)"
]
},
{
"cell_type": "markdown",
- "id": "919e3113-64dd-482a-9233-6607b3f63c1e",
"metadata": {
- "tags": []
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "86ae68d4-57da-41d9-91b4-625ef9465d60",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"source": [
- "#### Stop Triton Server on each executor"
+ "#### Shut down servers on each executor"
]
},
{
"cell_type": "code",
- "execution_count": 48,
- "id": "425d3b28-7705-45ba-8a18-ad34fc895219",
+ "execution_count": 46,
"metadata": {
- "tags": [
- "TRITON"
- ]
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "16fd4601-f6d5-4ddf-9b5e-d918ab0adf3a",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -1703,32 +1844,39 @@
"[True]"
]
},
- "execution_count": 48,
+ "execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
+ " \n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 49,
- "id": "2dec80ca-7a7c-46a9-97c0-7afb1572f5b9",
+ "execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@@ -1738,13 +1886,40 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "f43118ab-fc0a-4f64-a126-4302e615654a",
- "metadata": {},
+ "metadata": {
+ "application/vnd.databricks.v1+cell": {
+ "cellMetadata": {
+ "byteLimit": 2048000,
+ "rowLimit": 10000
+ },
+ "inputWidgets": {},
+ "nuid": "008c3e50-d321-4431-a9ab-919b35d1b042",
+ "showTitle": false,
+ "tableResultSettingsMap": {},
+ "title": ""
+ }
+ },
"outputs": [],
"source": []
}
],
"metadata": {
+ "application/vnd.databricks.v1+notebook": {
+ "dashboards": [],
+ "environmentMetadata": null,
+ "language": "python",
+ "notebookMetadata": {
+ "mostRecentlyExecutedCommandWithImplicitDF": {
+ "commandId": 421988607303514,
+ "dataframes": [
+ "_sqldf"
+ ]
+ },
+ "pythonIndentUnit": 4
+ },
+ "notebookName": "spark-triton-db.ipynb",
+ "widgets": {}
+ },
"kernelspec": {
"display_name": "spark-dl-torch",
"language": "python",
@@ -1764,5 +1939,5 @@
}
},
"nbformat": 4,
- "nbformat_minor": 5
+ "nbformat_minor": 4
}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/1/model.py
deleted file mode 100644
index b788c8930..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/1/model.py
+++ /dev/null
@@ -1,150 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import tensorflow as tf
- # Enable GPU memory growth
- gpus = tf.config.experimental.list_physical_devices('GPU')
- if gpus:
- try:
- for gpu in gpus:
- tf.config.experimental.set_memory_growth(gpu, True)
- except RuntimeError as e:
- print(e)
-
- print(tf.__version__)
-
- from transformers import AutoTokenizer, TFT5ForConditionalGeneration
-
- self.tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
- self.model = TFT5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- output_config = pb_utils.get_output_config_by_name(model_config, "output")
-
- # Convert Triton types to numpy types
- self.output_dtype = pb_utils.triton_string_to_numpy(output_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
- Python model, must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- output_dtype = self.output_dtype
-
- responses = []
-
- # Every Python backend must iterate over everyone of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "input")
- sentences = list(sentence_input.as_numpy())
- sentences = np.squeeze(sentences, -1).tolist()
- sentences = [s.decode('utf-8') for s in sentences]
-
- input_ids = self.tokenizer(sentences,
- padding="longest",
- max_length=512,
- truncation=True,
- return_tensors="tf").input_ids
- output_ids = self.model.generate(input_ids, max_length=20)
- outputs = np.array([self.tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- output_tensor = pb_utils.Tensor("output", outputs.astype(output_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
- # output_tensors=..., TritonError("An error occured"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/config.pbtxt
deleted file mode 100644
index 88b87130f..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_tf/config.pbtxt
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "hf_generation_tf"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "input"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "output"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../huggingface-tf.tar.gz"}
-}
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/1/model.py
deleted file mode 100644
index 8e9604daa..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/1/model.py
+++ /dev/null
@@ -1,144 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import torch
- print("torch: {}".format(torch.__version__))
- print("cuda: {}".format(torch.cuda.is_available()))
-
- import transformers
- print("transformers: {}".format(transformers.__version__))
-
- from transformers import T5Tokenizer, T5ForConditionalGeneration
- self.tokenizer = T5Tokenizer.from_pretrained("t5-small")
- self.model = T5ForConditionalGeneration.from_pretrained("t5-small")
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- output_config = pb_utils.get_output_config_by_name(model_config, "output")
-
- # Convert Triton types to numpy types
- self.output_dtype = pb_utils.triton_string_to_numpy(output_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
- Python model, must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- output_dtype = self.output_dtype
-
- responses = []
-
- # Every Python backend must iterate over everyone of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "input")
- sentences = list(sentence_input.as_numpy())
- sentences = np.squeeze(sentences, -1).tolist()
- sentences = [s.decode('utf-8') for s in sentences]
-
- input_ids = self.tokenizer(sentences,
- padding="longest",
- max_length=512,
- truncation=True,
- return_tensors="pt").input_ids
- output_ids = self.model.generate(input_ids, max_length=20)
- outputs = np.array([self.tokenizer.decode(o, skip_special_tokens=True) for o in output_ids])
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- output_tensor = pb_utils.Tensor("output", outputs.astype(output_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
- # output_tensors=..., TritonError("An error occured"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[output_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/config.pbtxt
deleted file mode 100644
index 47db54680..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_generation_torch/config.pbtxt
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "hf_generation_torch"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "input"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "output"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../huggingface-torch.tar.gz"}
-}
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/1/model.py
deleted file mode 100644
index 2a1bfda61..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/1/model.py
+++ /dev/null
@@ -1,147 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import tensorflow as tf
- print("tf: {}".format(tf.__version__))
-
- # Enable GPU memory growth
- gpus = tf.config.experimental.list_physical_devices('GPU')
- if gpus:
- try:
- for gpu in gpus:
- tf.config.experimental.set_memory_growth(gpu, True)
- except RuntimeError as e:
- print(e)
-
- from transformers import pipeline
- self.pipe = pipeline("sentiment-analysis", device=0)
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- label_config = pb_utils.get_output_config_by_name(model_config, "label")
- score_config = pb_utils.get_output_config_by_name(model_config, "score")
-
- # Convert Triton types to numpy types
- self.label_dtype = pb_utils.triton_string_to_numpy(label_config['data_type'])
- self.score_dtype = pb_utils.triton_string_to_numpy(score_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
- Python model, must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- label_dtype = self.label_dtype
- score_dtype = self.score_dtype
-
- responses = []
-
- # Every Python backend must iterate over everyone of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "sentence")
- sentences = [s.decode('utf-8') for s in sentence_input.as_numpy().flatten()]
-
- results = self.pipe(sentences)
-
- label = np.array([res['label'] for res in results])
- score = np.array([res['score'] for res in results])
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- label_tensor = pb_utils.Tensor("label", label.astype(label_dtype))
- score_tensor = pb_utils.Tensor("score", score.astype(score_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
- # output_tensors=..., TritonError("An error occured"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[label_tensor, score_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/config.pbtxt
deleted file mode 100644
index df7082ca4..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_tf/config.pbtxt
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "hf_pipeline_tf"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "sentence"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "label"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "score"
- data_type: TYPE_FP32
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../huggingface-tf.tar.gz"}
-}
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/1/model.py
deleted file mode 100644
index f01886c91..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/1/model.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import torch
- print("torch: {}".format(torch.__version__))
- print("cuda: {}".format(torch.cuda.is_available()))
-
- import transformers
- print("transformers: {}".format(transformers.__version__))
-
- from transformers import pipeline
- self.pipe = pipeline("sentiment-analysis", device=0)
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- label_config = pb_utils.get_output_config_by_name(model_config, "label")
- score_config = pb_utils.get_output_config_by_name(model_config, "score")
-
- # Convert Triton types to numpy types
- self.label_dtype = pb_utils.triton_string_to_numpy(label_config['data_type'])
- self.score_dtype = pb_utils.triton_string_to_numpy(score_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
-        Python model must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- label_dtype = self.label_dtype
- score_dtype = self.score_dtype
-
- responses = []
-
-        # Every Python backend must iterate over every one of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "sentence")
- sentences = [s.decode('utf-8') for s in sentence_input.as_numpy().flatten()]
-
- results = self.pipe(sentences)
-
- label = np.array([res['label'] for res in results])
- score = np.array([res['score'] for res in results])
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- label_tensor = pb_utils.Tensor("label", label.astype(label_dtype))
- score_tensor = pb_utils.Tensor("score", score.astype(score_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
-            #    output_tensors=..., TritonError("An error occurred"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[label_tensor, score_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/config.pbtxt
deleted file mode 100644
index 4e54607d2..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_pipeline_torch/config.pbtxt
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "hf_pipeline_torch"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "sentence"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "label"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "score"
- data_type: TYPE_FP32
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../huggingface-torch.tar.gz"}
-}
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/1/model.py
deleted file mode 100644
index f49805deb..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/1/model.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
-        the model to initialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import torch
- print("torch: {}".format(torch.__version__))
- print("cuda: {}".format(torch.cuda.is_available()))
-
- import transformers
- print("transformers: {}".format(transformers.__version__))
-
- from sentence_transformers import SentenceTransformer
- self.model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- embedding_config = pb_utils.get_output_config_by_name(model_config, "embedding")
-
- # Convert Triton types to numpy types
- self.embedding_dtype = pb_utils.triton_string_to_numpy(embedding_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
-        Python model must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- embedding_dtype = self.embedding_dtype
-
- responses = []
-
-        # Every Python backend must iterate over every one of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "sentence")
- sentences = list(sentence_input.as_numpy())
- sentences = np.squeeze(sentences, -1).tolist()
- sentences = [s.decode('utf-8') for s in sentences]
-
- embedding = self.model.encode(sentences)
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- embedding_tensor = pb_utils.Tensor("embedding", embedding.astype(embedding_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
-            #    output_tensors=..., TritonError("An error occurred"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[embedding_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/config.pbtxt
deleted file mode 100644
index 798cf4fc7..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/models_config/hf_transformer_torch/config.pbtxt
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "hf_transformer_torch"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "sentence"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "embedding"
- data_type: TYPE_FP32
- dims: [384]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../huggingface-torch.tar.gz"}
-}
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb
index dcba0be87..4163e3fa0 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_tf.ipynb
@@ -2,13 +2,16 @@
"cells": [
{
"cell_type": "markdown",
- "id": "60f7ac5d-4a95-4170-a0ac-a7faac9d9ef4",
+ "id": "9e9fe848",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# PySpark Huggingface Inferencing\n",
- "### Text Classification using Pipelines with Tensorflow\n",
+ "### Sentiment Analysis using Pipelines with Tensorflow\n",
"\n",
- "Based on: https://huggingface.co/docs/transformers/quicktour#pipeline-usage"
+ "In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis. \n",
+ "From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage"
]
},
{
@@ -16,14 +19,12 @@
"id": "1799fd4f",
"metadata": {},
"source": [
- "### Using TensorFlow\n",
- "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n",
- "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos."
+ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) "
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "0dd0f77b-ee1b-4477-a038-d25a4f1da0ea",
"metadata": {},
"outputs": [
@@ -31,35 +32,39 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 16:47:48.209366: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2024-10-03 16:47:48.215921: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-10-03 16:47:48.223519: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-10-03 16:47:48.225906: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "2024-10-03 16:47:48.231640: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "2025-01-10 23:14:44.169405: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-01-10 23:14:44.176824: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-01-10 23:14:44.184715: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2025-01-10 23:14:44.186998: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2025-01-10 23:14:44.193224: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2024-10-03 16:47:48.625790: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ "2025-01-10 23:14:44.594518: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
"import tensorflow as tf\n",
- "from transformers import pipeline"
+ "from transformers import pipeline\n",
+ "\n",
+    "# Manually enable Huggingface tokenizer parallelism so it isn't silently disabled when PySpark forks worker processes.\n",
+    "# See https://github.com/huggingface/transformers/issues/5486 for more info.\n",
+ "import os\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "d80fc3f8",
"metadata": {},
"outputs": [],
"source": [
- "# set device if tensorflow gpu is available\n",
"device = 0 if tf.config.list_physical_devices('GPU') else -1"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "e60a2877",
"metadata": {},
"outputs": [
@@ -72,8 +77,6 @@
}
],
"source": [
- "print(tf.__version__)\n",
- "\n",
"# Enable GPU memory growth\n",
"gpus = tf.config.experimental.list_physical_devices('GPU')\n",
"if gpus:\n",
@@ -81,12 +84,14 @@
" for gpu in gpus:\n",
" tf.config.experimental.set_memory_growth(gpu, True)\n",
" except RuntimeError as e:\n",
- " print(e)"
+ " print(e)\n",
+ "\n",
+ "print(tf.__version__)"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "553b28d2-a5d1-4d07-8a49-8f82b808e738",
"metadata": {},
"outputs": [
@@ -95,8 +100,14 @@
"output_type": "stream",
"text": [
"No model was supplied, defaulted to distilbert/distilbert-base-uncased-finetuned-sst-2-english and revision 714eb0f (https://huggingface.co/distilbert/distilbert-base-uncased-finetuned-sst-2-english).\n",
- "Using a pipeline without specifying a model name and revision in production is not recommended.\n",
- "2024-10-03 16:47:49.863791: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46447 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n",
+ "Using a pipeline without specifying a model name and revision in production is not recommended.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-01-10 23:14:46.051516: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 45948 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n",
"All PyTorch model weights were used when initializing TFDistilBertForSequenceClassification.\n",
"\n",
"All the weights of TFDistilBertForSequenceClassification were initialized from the PyTorch model.\n",
@@ -110,7 +121,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "3b91fe91-b725-4564-ae93-56e3fb51e47c",
"metadata": {},
"outputs": [
@@ -120,7 +131,7 @@
"[{'label': 'POSITIVE', 'score': 0.9997794032096863}]"
]
},
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -131,7 +142,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c",
"metadata": {},
"outputs": [
@@ -152,15 +163,15 @@
},
{
"cell_type": "markdown",
- "id": "30c90100",
+ "id": "e29ee6d8",
"metadata": {},
"source": [
- "#### Use another model and tokenizer in the pipeline"
+ "Let's try a different model and tokenizer in the pipeline."
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"id": "cd9d3349",
"metadata": {},
"outputs": [],
@@ -170,7 +181,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"id": "99e21b58",
"metadata": {},
"outputs": [
@@ -178,10 +189,9 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "Some layers from the model checkpoint at nlptown/bert-base-multilingual-uncased-sentiment were not used when initializing TFBertForSequenceClassification: ['dropout_37']\n",
- "- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
- "- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
- "All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at nlptown/bert-base-multilingual-uncased-sentiment.\n",
+ "All PyTorch model weights were used when initializing TFBertForSequenceClassification.\n",
+ "\n",
+ "All the weights of TFBertForSequenceClassification were initialized from the PyTorch model.\n",
"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.\n"
]
}
@@ -195,179 +205,362 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"id": "31079133",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "[{'label': '5 stars', 'score': 0.7272655963897705}]"
+ "[{'label': '5 stars', 'score': 0.7272477746009827}]"
]
},
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "classifier = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer)\n",
+ "classifier = pipeline(\"sentiment-analysis\", model=model, tokenizer=tokenizer, device=device)\n",
"classifier(\"Nous sommes très heureux de vous présenter la bibliothèque 🤗 Transformers.\")"
]
},
{
"cell_type": "markdown",
- "id": "ae92b15e-0da0-46c3-81a3-fabaedbfc42c",
+ "id": "e6357234",
"metadata": {},
"source": [
- "## Inference using Spark DL API"
+ "## PySpark"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"id": "69dd6a1a-f450-47f0-9dbf-ad250585a011",
"metadata": {},
"outputs": [],
"source": [
- "import os\n",
- "import pandas as pd\n",
"from pyspark.sql.functions import col, struct, pandas_udf\n",
"from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.types import FloatType, StringType, StructField, StructType\n",
+ "from pyspark.sql.types import *\n",
"from pyspark.sql import SparkSession\n",
"from pyspark import SparkConf"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "6e0e0dd7",
+ "execution_count": 11,
+ "id": "287b1e96",
"metadata": {},
"outputs": [],
"source": [
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ "import os\n",
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "50e124cd",
+ "metadata": {},
+ "source": [
+ "Check the cluster environment to handle any platform-specific Spark configurations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "36001f55",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "48c7271a",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
"\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "6e0e0dd7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/10 23:14:48 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/10 23:14:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+ "25/01/10 23:14:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ ]
+ }
+ ],
+ "source": [
"conf = SparkConf()\n",
+ "\n",
"if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " \n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " source = \"/usr/lib/x86_64-linux-gnu/libstdc++.so.6\"\n",
+ " target = f\"{conda_env}/lib/libstdc++.so.6\"\n",
+ " try:\n",
+ " if os.path.islink(target) or os.path.exists(target):\n",
+ " os.remove(target)\n",
+ " os.symlink(source, target)\n",
+ " except OSError as e:\n",
+ " print(f\"Error creating symlink: {e}\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
"spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
"sc = spark.sparkContext"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"id": "42d70208",
"metadata": {},
"outputs": [],
"source": [
- "from datasets import load_dataset\n",
- "\n",
- "# Load the IMDB dataset\n",
- "data = load_dataset(\"imdb\", split=\"test\")\n",
- "\n",
- "lines = []\n",
- "for example in data:\n",
- " # first sentence only\n",
- " lines.append([example[\"text\"]])\n",
- "\n",
- "len(lines)\n",
- "\n",
- "df = spark.createDataFrame(lines, ['lines']).repartition(8).cache()"
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "95ded4b2",
+ "metadata": {},
+ "source": [
+ "#### Create PySpark DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
"id": "ac24f3c2",
"metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "StructType([StructField('text', StringType(), True)])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "1db4db3a",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/03 16:47:58 WARN TaskSetManager: Stage 0 contains a task of very large size (3860 KiB). The maximum recommended task size is 1000 KiB.\n",
" \r"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "25000"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "df.write.mode(\"overwrite\").parquet(\"imdb_test\")"
+ "df.count()"
]
},
{
"cell_type": "code",
- "execution_count": 15,
- "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574",
+ "execution_count": 17,
+ "id": "517fe2e9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 20:44:11 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
},
+ {
+ "data": {
+ "text/plain": [
+       "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.take(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "e176d28b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 20:44:12 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "395e0374",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for sentiment analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "26693020",
+ "metadata": {},
+ "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+--------------------------------------------------------------------------------+\n",
- "| sentence|\n",
- "+--------------------------------------------------------------------------------+\n",
- "| |\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pi...|\n",
- "|I watched this movie to see the direction one of the most promising young tal...|\n",
- "| This movie makes you wish imdb would let you vote a zero|\n",
-      "|I never want to see this movie again!
Not only is it dreadfully ba...|\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, w...|\n",
- "| Don't get me wrong, I love the TV series of League Of Gentlemen|\n",
- "|Did you ever think, like after watching a horror movie with a group of friend...|\n",
- "| Awful, awful, awful|\n",
- "|This movie seems a little clunky around the edges, like not quite enough zani...|\n",
- "|I rented this movie hoping that it would provide some good entertainment and ...|\n",
- "|Well, where to start describing this celluloid debacle? You already know the ...|\n",
- "| I hoped for this show to be somewhat realistic|\n",
- "| All I have to say is one word|\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy scr...|\n",
- "|This critique tells the story of 4 little friends who went to watch Angels an...|\n",
- "| This review contains a partial spoiler|\n",
- "| I'm rather surprised that anybody found this film touching or moving|\n",
- "| If you like bad movies (and you must to watch this one) here's a good one|\n",
- "|This is really bad, the characters were bland, the story was boring, and ther...|\n",
- "+--------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n",
+ "| There were two things I hated about WASTED : The directing and the script |\n",
+ "| I'm rather surprised that anybody found this film touching or moving|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n",
+ "| This movie has been done before|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
+      "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get involved in such mindles...|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
+ "| My wife rented this movie and then conveniently never got to see it|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n",
+ "| This show is like watching someone who is in training to someday host a show|\n",
+ "| Sigh|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "# only use first sentence of IMDB reviews\n",
- "@pandas_udf(\"string\")\n",
- "def first_sentence(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([s.split(\".\")[0] for s in text])\n",
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n",
+ "df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "76dc525c",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
"\n",
- "df = spark.read.parquet(\"imdb_test\").withColumn(\"sentence\", first_sentence(col(\"lines\"))).select(\"sentence\").limit(100).cache()\n",
- "df.show(truncate=80)"
+    "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays (a minimal sketch follows this cell) \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
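For reference, the `predict_batch_fn` named above is defined in the next cell, whose body is untouched by this diff (only its execution count appears in the hunk below). A minimal sketch of such a loader, assuming the same sentiment-analysis pipeline used earlier in this notebook; this is illustrative only, not the notebook's exact code:

```python
# Sketch of a predict_batch_udf loader: a zero-argument function that loads the
# model once per executor and returns a predict function over numpy batches.
def predict_batch_fn():
    import numpy as np
    from transformers import pipeline

    pipe = pipeline("sentiment-analysis", device=0)  # assumes a GPU, as in the cells above

    def predict(inputs: np.ndarray) -> dict:
        # predict_batch_udf delivers the string column as a numpy array.
        results = pipe(inputs.tolist())
        # For a struct return type, a dict of column name -> numpy array is accepted;
        # the keys must match the StructType fields declared in the next cells.
        return {
            "label": np.array([r["label"] for r in results]),
            "score": np.array([r["score"] for r in results], dtype=np.float32),
        }

    return predict
```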
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 21,
"id": "0da9d25c-5ebe-4503-bb19-154fcc047cbf",
"metadata": {},
"outputs": [],
@@ -394,7 +587,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 22,
"id": "78afef29-ee30-4267-9fb6-be2dcb86cbba",
"metadata": {},
"outputs": [],
@@ -404,12 +597,12 @@
" StructField(\"label\", StringType(), True),\n",
" StructField(\"score\", FloatType(), True)\n",
" ]),\n",
- " batch_size=10)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 23,
"id": "a5bc327e-89cf-4731-82e6-e66cb93deef1",
"metadata": {},
"outputs": [
@@ -417,15 +610,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 11:> (0 + 1) / 1]\r"
+ "[Stage 18:==============> (2 + 6) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 9.15 ms, sys: 6.76 ms, total: 15.9 ms\n",
- "Wall time: 5 s\n"
+ "CPU times: user 10.4 ms, sys: 3.63 ms, total: 14.1 ms\n",
+ "Wall time: 4.56 s\n"
]
},
{
@@ -438,14 +631,15 @@
],
"source": [
"%%time\n",
+ "# first pass caches model/fn\n",
"# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(struct(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 24,
"id": "ac642895-cfd6-47ee-9b21-02e7835424e4",
"metadata": {},
"outputs": [
@@ -453,15 +647,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 13:> (0 + 1) / 1]\r"
+ "[Stage 21:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.86 ms, sys: 2.19 ms, total: 7.05 ms\n",
- "Wall time: 2.81 s\n"
+ "CPU times: user 2.93 ms, sys: 3.98 ms, total: 6.91 ms\n",
+ "Wall time: 1.38 s\n"
]
},
{
@@ -474,14 +668,13 @@
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(\"sentence\")).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 25,
"id": "76a44d80-d5db-405f-989c-7246379cfb95",
"metadata": {},
"outputs": [
@@ -489,15 +682,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 15:> (0 + 1) / 1]\r"
+ "[Stage 24:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 3.91 ms, sys: 1.96 ms, total: 5.87 ms\n",
- "Wall time: 2.76 s\n"
+ "CPU times: user 3.23 ms, sys: 2.55 ms, total: 5.77 ms\n",
+ "Wall time: 1.39 s\n"
]
},
{
@@ -510,14 +703,13 @@
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(col(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 26,
"id": "c01761b3-c766-46b0-ae0b-fcf968ffb3a1",
"metadata": {},
"outputs": [
@@ -526,32 +718,39 @@
"output_type": "stream",
"text": [
"+--------------------------------------------------------------------------------+--------+----------+\n",
- "| sentence| label| score|\n",
+ "| input| label| score|\n",
"+--------------------------------------------------------------------------------+--------+----------+\n",
- "| |POSITIVE|0.74807304|\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pi...|NEGATIVE| 0.9996724|\n",
- "|I watched this movie to see the direction one of the most promising young tal...|POSITIVE| 0.9994948|\n",
- "| This movie makes you wish imdb would let you vote a zero|NEGATIVE| 0.9981299|\n",
-      "|I never want to see this movie again!
Not only is it dreadfully ba...|NEGATIVE|0.99883264|\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, w...|POSITIVE| 0.9901753|\n",
- "| Don't get me wrong, I love the TV series of League Of Gentlemen|POSITIVE|0.99983096|\n",
- "|Did you ever think, like after watching a horror movie with a group of friend...|POSITIVE| 0.9992768|\n",
- "| Awful, awful, awful|NEGATIVE| 0.9997433|\n",
- "|This movie seems a little clunky around the edges, like not quite enough zani...|NEGATIVE| 0.9996525|\n",
- "|I rented this movie hoping that it would provide some good entertainment and ...|NEGATIVE|0.99643254|\n",
- "|Well, where to start describing this celluloid debacle? You already know the ...|NEGATIVE|0.99973005|\n",
- "| I hoped for this show to be somewhat realistic|POSITIVE| 0.8417903|\n",
- "| All I have to say is one word|NEGATIVE|0.97844803|\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy scr...|NEGATIVE| 0.9997701|\n",
- "|This critique tells the story of 4 little friends who went to watch Angels an...|POSITIVE| 0.9942386|\n",
- "| This review contains a partial spoiler|NEGATIVE|0.99620205|\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\n",
+ "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979007|\n",
"| I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\n",
- "| If you like bad movies (and you must to watch this one) here's a good one|POSITIVE| 0.9936475|\n",
- "|This is really bad, the characters were bland, the story was boring, and ther...|NEGATIVE|0.99953806|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE| 0.982114|\n",
+ "| This movie has been done before|NEGATIVE|0.94210696|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\n",
+      "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get invo...|NEGATIVE|0.99956733|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.994562|\n",
+ "| My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\n",
+ "| This show is like watching someone who is in training to someday host a show|NEGATIVE|0.97015846|\n",
+ "| Sigh|NEGATIVE| 0.9923151|\n",
"+--------------------------------------------------------------------------------+--------+----------+\n",
"only showing top 20 rows\n",
"\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
@@ -560,289 +759,428 @@
},
{
"cell_type": "markdown",
- "id": "eb826fde-99d9-43fe-8ddc-f5acbe76b4e9",
+ "id": "fc8127d9",
"metadata": {},
"source": [
- "### Using Triton Inference Server\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
"\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment. "
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
+ "\n",
+ ""
]
},
{
- "cell_type": "markdown",
- "id": "10368010-f94d-4167-91a1-2cf9ed91a2c9",
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b",
"metadata": {},
+ "outputs": [],
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n huggingface-tf -c conda-forge python=3.10.0\n",
- "conda activate huggingface-tf\n",
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import tensorflow as tf\n",
+ " from transformers import pipeline\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
"\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 tensorflow[and-cuda] tf-keras transformers conda-pack\n",
+ " print(f\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\")\n",
+ " # Enable GPU memory growth\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ " \n",
+ " device = 0 if tf.config.list_physical_devices('GPU') else -1\n",
+ " \n",
+ " pipe = pipeline(\"sentiment-analysis\", device=device)\n",
+ " print(f\"SERVER: Using {device} device.\")\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"]).tolist()\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " return {\n",
+ " \"outputs\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"SentimentAnalysis\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
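+    "    # Launch the Triton server in a background process, then block until the model is ready to serve (or time out).\n",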
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "19d9028d",
+ "metadata": {},
+ "source": [
+ "#### Start Triton servers\n",
"\n",
- "conda-pack # huggingface-tf.tar.gz\n",
- "```"
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
]
},
{
"cell_type": "code",
- "execution_count": 22,
- "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 29,
+ "id": "144acb8e-4c08-40fc-a9ed-f721c409ee68",
+ "metadata": {},
"outputs": [],
"source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct, pandas_udf\n",
- "from pyspark.sql.types import FloatType, StringType, StructField, StructType"
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
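+    "    # Request more than half the executor cores (all of them when the RAPIDS SQL plugin is active) so at most one of these tasks can be scheduled per executor.\n",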
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+    "    print(f\"Requesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f85dc27",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
- "execution_count": 23,
- "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 30,
+ "id": "714e6ef9",
+ "metadata": {},
"outputs": [],
"source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/hf_pipeline_tf models\n",
- "\n",
- "# add custom execution environment\n",
- "cp huggingface-tf.tar.gz models"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "156de815",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+      "Requesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
"cell_type": "markdown",
- "id": "db4a5b06-126a-4bc4-baae-a45ea30832a7",
- "metadata": {
- "tags": []
- },
+ "id": "736ac5f4",
+ "metadata": {},
"source": [
- "#### Start Triton Server on each executor"
+    "Triton occupies three ports: one each for HTTP requests, gRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 24,
- "id": "144acb8e-4c08-40fc-a9ed-f721c409ee68",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 32,
+ "id": "f368460c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
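+    "    # Collect three ports not currently in use, scanning upward from 7000 (HTTP, gRPC, metrics).\n",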
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "4b6044f9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using ports [7000, 7001, 7002]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"SentimentAnalysis\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "f75c30c5",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 28:> (0 + 1) / 1]\r"
]
},
{
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2714026\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "huggingface_cache_dir = \"{}/.cache/huggingface\".format(os.path.expanduser('~'))\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " environment=[\n",
- " \"TRANSFORMERS_CACHE=/cache\"\n",
- " ],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"256M\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " huggingface_cache_dir: {\"bind\": \"/cache\", \"mode\": \"rw\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " \n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " \n",
- " elapsed = 0\n",
- " timeout = 120\n",
- " ready = False\n",
- " while not ready and elapsed < timeout:\n",
- " try:\n",
- " time.sleep(5)\n",
- " elapsed += 5\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " pass\n",
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e4c4017c",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "35bf6939",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "431b864c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
"\n",
- " return [True]\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
"\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " result_data = np.squeeze(result_data[\"outputs\"], -1)\n",
+ " return [json.loads(o) for o in result_data]\n",
+ " \n",
+ " return infer_batch"
]
},
{
"cell_type": "markdown",
- "id": "c24d77ab-60d3-45eb-a9c2-dc811eca0af4",
+ "id": "5a8ec7be",
"metadata": {},
"source": [
- "#### Run inference"
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 37,
"id": "d53fb283-bf9e-4571-8c68-b75a41f1f067",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "# only use first sentence of IMDB reviews\n",
"@pandas_udf(\"string\")\n",
- "def first_sentence(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([s.split(\".\")[0] for s in text])\n",
- "\n",
- "df = spark.read.parquet(\"imdb_test\").withColumn(\"sentence\", first_sentence(col(\"lines\"))).select(\"sentence\").limit(1000)"
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 38,
"id": "29b0cc0d-c480-4e4a-bd41-207dc314cba5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 20:44:27 WARN CacheManager: Asked to cache already cached data.\n"
+ ]
+ }
+ ],
"source": [
- "def triton_fn(triton_uri, model_name):\n",
- " import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "da39990f",
+ "metadata": {},
+ "source": [
+ "#### Run Inference"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 39,
"id": "3930cfcd-3284-4c6a-a9b5-36b8053fe899",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "from functools import partial\n",
- "\n",
- "classify = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"hf_pipeline_tf\"),\n",
+ "classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
" return_type=StructType([\n",
" StructField(\"label\", StringType(), True),\n",
" StructField(\"score\", FloatType(), True)\n",
" ]),\n",
" input_tensor_shapes=[[1]],\n",
- " batch_size=100)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 40,
"id": "8eecbf23-4e9e-4d4c-8645-98209b25db2c",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 20:> (0 + 1) / 1]\r"
+ "[Stage 32:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 22.5 ms, sys: 5.9 ms, total: 28.4 ms\n",
- "Wall time: 24.6 s\n"
+ "CPU times: user 6.65 ms, sys: 5.3 ms, total: 11.9 ms\n",
+ "Wall time: 7.44 s\n"
]
},
{
@@ -857,33 +1195,29 @@
"%%time\n",
"# first pass caches model/fn\n",
"# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(struct(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 41,
"id": "566ba28c-0ca4-4479-a24a-c8a362228b89",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 21:> (0 + 1) / 1]\r"
+ "[Stage 35:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 12.2 ms, sys: 10.1 ms, total: 22.3 ms\n",
- "Wall time: 23.8 s\n"
+ "CPU times: user 9.85 ms, sys: 2.25 ms, total: 12.1 ms\n",
+ "Wall time: 7 s\n"
]
},
{
@@ -896,34 +1230,29 @@
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(\"sentence\")).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 42,
"id": "44c7e776-08da-484a-ba07-9d6add1a0f15",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 22:> (0 + 1) / 1]\r"
+ "[Stage 38:==================================================> (7 + 1) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 8.74 ms, sys: 8.23 ms, total: 17 ms\n",
- "Wall time: 23.8 s\n"
+ "CPU times: user 3.63 ms, sys: 7.25 ms, total: 10.9 ms\n",
+ "Wall time: 7.08 s\n"
]
},
{
@@ -936,56 +1265,44 @@
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(col(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": 43,
"id": "f61d79f8-661e-4d9e-a3aa-c0754b854603",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 23:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
- "|sentence |label |score |\n",
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
- "| |POSITIVE|0.74807304|\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pimp story without the gratuitous sex scenes, either hard core or soft core, therefore reads like a public information film from the fifties, give this a wide miss, use a barge pole if you can|NEGATIVE|0.9996724 |\n",
- "|I watched this movie to see the direction one of the most promising young talents in movies was going |POSITIVE|0.9994948 |\n",
- "|This movie makes you wish imdb would let you vote a zero |NEGATIVE|0.9981299 |\n",
- "|I never want to see this movie again!
Not only is it dreadfully bad, but I can't stand seeing my hero Stan Laurel looking so old and sick |NEGATIVE|0.99883264|\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, where the entire youth group laughed at it |POSITIVE|0.9901753 |\n",
- "|Don't get me wrong, I love the TV series of League Of Gentlemen |POSITIVE|0.99983096|\n",
- "|Did you ever think, like after watching a horror movie with a group of friends: \"Wow, this is so cool! We have got to make a splatter horror movie ourselves some day soon |POSITIVE|0.9992768 |\n",
- "|Awful, awful, awful |NEGATIVE|0.9997433 |\n",
- "|This movie seems a little clunky around the edges, like not quite enough zaniness was thrown it when it should have been |NEGATIVE|0.9996525 |\n",
- "|I rented this movie hoping that it would provide some good entertainment and some cool poker knowledge or stories |NEGATIVE|0.99643254|\n",
- "|Well, where to start describing this celluloid debacle? You already know the big fat NADA passing as a plot, so let's jut point out that this is so PC it's offensive |NEGATIVE|0.99973005|\n",
- "|I hoped for this show to be somewhat realistic |POSITIVE|0.8417903 |\n",
- "|All I have to say is one word |NEGATIVE|0.97844803|\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy screenplay |NEGATIVE|0.9997701 |\n",
- "|This critique tells the story of 4 little friends who went to watch Angels and Demons the movie on the first night it came out, even though it was a school night, because \"Angels and Demons is worth it |POSITIVE|0.9942386 |\n",
- "|This review contains a partial spoiler |NEGATIVE|0.99620205|\n",
- "|I'm rather surprised that anybody found this film touching or moving |POSITIVE|0.83874947|\n",
- "|If you like bad movies (and you must to watch this one) here's a good one |POSITIVE|0.9936475 |\n",
- "|This is really bad, the characters were bland, the story was boring, and there is no sex scene |NEGATIVE|0.99953806|\n",
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
+ "+--------------------------------------------------------------------------------+--------+----------+\n",
+ "| input| label| score|\n",
+ "+--------------------------------------------------------------------------------+--------+----------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984061|\n",
+ "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979007|\n",
+ "| I'm rather surprised that anybody found this film touching or moving|POSITIVE|0.83874947|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99727434|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE| 0.982114|\n",
+ "| This movie has been done before|NEGATIVE|0.94210696|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE| 0.9967818|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985843|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926835|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get invo...|NEGATIVE|0.99956733|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE| 0.9985662|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.994562|\n",
+ "| My wife rented this movie and then conveniently never got to see it|NEGATIVE|0.99841607|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953243|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE| 0.9997198|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE| 0.9965172|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE| 0.9986059|\n",
+ "| This show is like watching someone who is in training to someday host a show|NEGATIVE|0.97015846|\n",
+ "| Sigh|NEGATIVE| 0.9923151|\n",
+ "+--------------------------------------------------------------------------------+--------+----------+\n",
"only showing top 20 rows\n",
"\n"
]
@@ -999,29 +1316,30 @@
}
],
"source": [
- "preds.show(truncate=False)"
+ "preds.show(truncate=80)"
]
},
{
"cell_type": "markdown",
- "id": "e197c146-1794-47f0-bcd9-7e8d8ab8625f",
- "metadata": {
- "tags": []
- },
+ "id": "fac2ae57",
+ "metadata": {},
"source": [
- "#### Stop Triton Server on each executor"
+ "#### Shut down server on each executor"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 44,
"id": "425d3b28-7705-45ba-8a18-ad34fc895219",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -1035,31 +1353,39 @@
"[True]"
]
},
- "execution_count": 32,
+ "execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 45,
"id": "9f19643c-4ee4-44f2-b762-2078c0c8eba9",
"metadata": {},
"outputs": [],
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb
index 1e99ed365..ffdd037a8 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/pipelines_torch.ipynb
@@ -5,15 +5,18 @@
"id": "60f7ac5d-4a95-4170-a0ac-a7faac9d9ef4",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# PySpark Huggingface Inferencing\n",
- "### Text Classification using Pipelines with PyTorch\n",
+ "### Sentiment Analysis using Pipelines with PyTorch\n",
"\n",
- "Based on: https://huggingface.co/docs/transformers/quicktour#pipeline-usage"
+ "In this notebook, we demonstrate distributed inference with Huggingface Pipelines to perform sentiment analysis. \n",
+ "From: https://huggingface.co/docs/transformers/quicktour#pipeline-usage"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"id": "0dd0f77b-ee1b-4477-a038-d25a4f1da0ea",
"metadata": {},
"outputs": [],
@@ -24,18 +27,17 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"id": "e1f756c6",
"metadata": {},
"outputs": [],
"source": [
- "# set device if gpu is available\n",
- "device = 0 if torch.cuda.is_available() else -1"
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "553b28d2-a5d1-4d07-8a49-8f82b808e738",
"metadata": {},
"outputs": [
@@ -54,7 +56,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "3b91fe91-b725-4564-ae93-56e3fb51e47c",
"metadata": {},
"outputs": [
@@ -64,7 +66,7 @@
"[{'label': 'POSITIVE', 'score': 0.9997795224189758}]"
]
},
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -75,7 +77,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"id": "0be39eb3-462c-42ff-b8f4-09f4e4fe3a3c",
"metadata": {},
"outputs": [
@@ -99,12 +101,12 @@
"id": "f752f929",
"metadata": {},
"source": [
- "#### Use another model and tokenizer in the pipeline"
+ "Let's try a different model and tokenizer in the pipeline."
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"id": "9861865f",
"metadata": {},
"outputs": [],
@@ -114,7 +116,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "506e7834",
"metadata": {},
"outputs": [],
@@ -127,7 +129,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"id": "312017fc",
"metadata": {},
"outputs": [
@@ -137,7 +139,7 @@
"[{'label': '5 stars', 'score': 0.7272652983665466}]"
]
},
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
@@ -152,154 +154,320 @@
"id": "ae92b15e-0da0-46c3-81a3-fabaedbfc42c",
"metadata": {},
"source": [
- "## Inference using Spark DL API"
+ "## PySpark"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"id": "69dd6a1a-f450-47f0-9dbf-ad250585a011",
"metadata": {},
"outputs": [],
"source": [
- "import pandas as pd\n",
"from pyspark.sql.functions import col, struct, pandas_udf\n",
"from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.types import FloatType, StringType, StructField, StructType\n",
+ "from pyspark.sql.types import *\n",
"from pyspark.sql import SparkSession\n",
- "from pyspark import SparkConf\n",
- "import os"
+ "from pyspark import SparkConf"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "id": "6e0e0dd7",
+ "execution_count": 10,
+ "id": "42c19ad8",
"metadata": {},
"outputs": [],
"source": [
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ "import os\n",
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3f1a0210",
+ "metadata": {},
+ "source": [
+ "Check the cluster environment to handle any platform-specific Spark configurations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "79aaf5ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b99f9c38",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
"\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "6e0e0dd7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 18:29:40 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/06 18:29:40 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+ "25/01/06 18:29:40 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ ]
+ }
+ ],
+ "source": [
"conf = SparkConf()\n",
+ "\n",
"if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
"spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
"sc = spark.sparkContext"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
"id": "42d70208",
"metadata": {},
"outputs": [],
"source": [
- "from datasets import load_dataset\n",
- "\n",
- "# Load the IMDB dataset\n",
- "data = load_dataset(\"imdb\", split=\"test\")\n",
- "\n",
- "lines = []\n",
- "for example in data:\n",
- " # first sentence only\n",
- " lines.append([example[\"text\"]])\n",
- "\n",
- "len(lines)\n",
- "\n",
- "df = spark.createDataFrame(lines, ['lines']).repartition(8).cache()"
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "de0f421d",
+ "metadata": {},
+ "source": [
+ "#### Create PySpark DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"id": "ac24f3c2",
"metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "StructType([StructField('text', StringType(), True)])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "b0d1876b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "25000"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "06ec6bb6",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/03 16:44:02 WARN TaskSetManager: Stage 0 contains a task of very large size (3860 KiB). The maximum recommended task size is 1000 KiB.\n",
- " \r"
+ "25/01/06 18:29:47 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "df.write.mode(\"overwrite\").parquet(\"imdb_test\")"
+ "df.take(1)"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574",
+ "execution_count": 17,
+ "id": "eeadf4e2",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 18:29:47 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
- },
+ }
+ ],
+ "source": [
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "09cddc95",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for sentiment analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "9665b7b6-d7e9-4bd4-b29d-7a449ac5b574",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "74cfa3ff",
+ "metadata": {},
+ "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+--------------------------------------------------------------------------------+\n",
- "| sentence|\n",
- "+--------------------------------------------------------------------------------+\n",
- "| |\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pi...|\n",
- "|I watched this movie to see the direction one of the most promising young tal...|\n",
- "| This movie makes you wish imdb would let you vote a zero|\n",
- "|I never want to see this movie again!
Not only is it dreadfully ba...|\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, w...|\n",
- "| Don't get me wrong, I love the TV series of League Of Gentlemen|\n",
- "|Did you ever think, like after watching a horror movie with a group of friend...|\n",
- "| Awful, awful, awful|\n",
- "|This movie seems a little clunky around the edges, like not quite enough zani...|\n",
- "|I rented this movie hoping that it would provide some good entertainment and ...|\n",
- "|Well, where to start describing this celluloid debacle? You already know the ...|\n",
- "| I hoped for this show to be somewhat realistic|\n",
- "| All I have to say is one word|\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy scr...|\n",
- "|This critique tells the story of 4 little friends who went to watch Angels an...|\n",
- "| This review contains a partial spoiler|\n",
- "| I'm rather surprised that anybody found this film touching or moving|\n",
- "| If you like bad movies (and you must to watch this one) here's a good one|\n",
- "|This is really bad, the characters were bland, the story was boring, and ther...|\n",
- "+--------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n",
+ "| There were two things I hated about WASTED : The directing and the script |\n",
+ "| I'm rather surprised that anybody found this film touching or moving|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n",
+ "| This movie has been done before|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get involved in such mindles...|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
+ "| My wife rented this movie and then conveniently never got to see it|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n",
+ "| This show is like watching someone who is in training to someday host a show|\n",
+ "| Sigh|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "# only use first sentence of IMDB reviews\n",
- "@pandas_udf(\"string\")\n",
- "def first_sentence(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([s.split(\".\")[0] for s in text])\n",
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n",
+ "df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1ad92750",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
"\n",
- "df = spark.read.parquet(\"imdb_test\").withColumn(\"sentence\", first_sentence(col(\"lines\"))).select(\"sentence\").limit(100).cache()\n",
- "df.show(truncate=80)"
+ "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 20,
"id": "0da9d25c-5ebe-4503-bb19-154fcc047cbf",
"metadata": {},
"outputs": [],
@@ -308,7 +476,7 @@
" import torch\n",
" from transformers import pipeline\n",
" \n",
- " device = 0 if torch.cuda.is_available() else -1\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" pipe = pipeline(\"sentiment-analysis\", device=device)\n",
" def predict(inputs):\n",
" return pipe(inputs.tolist())\n",
@@ -317,7 +485,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 21,
"id": "78afef29-ee30-4267-9fb6-be2dcb86cbba",
"metadata": {},
"outputs": [],
@@ -327,12 +495,12 @@
" StructField(\"label\", StringType(), True),\n",
" StructField(\"score\", FloatType(), True)\n",
" ]),\n",
- " batch_size=10)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 22,
"id": "a5bc327e-89cf-4731-82e6-e66cb93deef1",
"metadata": {},
"outputs": [
@@ -340,15 +508,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 11:> (0 + 1) / 1]\r"
+ "[Stage 18:==============> (2 + 6) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 12.6 ms, sys: 2.39 ms, total: 15 ms\n",
- "Wall time: 2.02 s\n"
+ "CPU times: user 14.8 ms, sys: 4.23 ms, total: 19 ms\n",
+ "Wall time: 3.15 s\n"
]
},
{
@@ -361,14 +529,15 @@
],
"source": [
"%%time\n",
+ "# first pass caches model/fn\n",
"# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(struct(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 23,
"id": "ac642895-cfd6-47ee-9b21-02e7835424e4",
"metadata": {},
"outputs": [
@@ -376,21 +545,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.13 ms, sys: 1.06 ms, total: 3.19 ms\n",
- "Wall time: 237 ms\n"
+ "CPU times: user 2.59 ms, sys: 2.33 ms, total: 4.91 ms\n",
+ "Wall time: 393 ms\n"
]
}
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(\"sentence\")).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 24,
"id": "76a44d80-d5db-405f-989c-7246379cfb95",
"metadata": {},
"outputs": [
@@ -398,21 +566,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 2.28 ms, sys: 790 μs, total: 3.07 ms\n",
- "Wall time: 230 ms\n"
+ "CPU times: user 2.65 ms, sys: 2.41 ms, total: 5.06 ms\n",
+ "Wall time: 398 ms\n"
]
}
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(col(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 25,
"id": "c01761b3-c766-46b0-ae0b-fcf968ffb3a1",
"metadata": {},
"outputs": [
@@ -421,28 +588,28 @@
"output_type": "stream",
"text": [
"+--------------------------------------------------------------------------------+--------+----------+\n",
- "| sentence| label| score|\n",
+ "| input| label| score|\n",
"+--------------------------------------------------------------------------------+--------+----------+\n",
- "| |POSITIVE| 0.7481212|\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pi...|NEGATIVE|0.99967253|\n",
- "|I watched this movie to see the direction one of the most promising young tal...|POSITIVE| 0.9994943|\n",
- "| This movie makes you wish imdb would let you vote a zero|NEGATIVE| 0.9981305|\n",
- "|I never want to see this movie again!
Not only is it dreadfully ba...|NEGATIVE| 0.9988337|\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, w...|POSITIVE| 0.9901974|\n",
- "| Don't get me wrong, I love the TV series of League Of Gentlemen|POSITIVE| 0.9998311|\n",
- "|Did you ever think, like after watching a horror movie with a group of friend...|POSITIVE| 0.9992779|\n",
- "| Awful, awful, awful|NEGATIVE| 0.9997433|\n",
- "|This movie seems a little clunky around the edges, like not quite enough zani...|NEGATIVE|0.99965274|\n",
- "|I rented this movie hoping that it would provide some good entertainment and ...|NEGATIVE|0.99642426|\n",
- "|Well, where to start describing this celluloid debacle? You already know the ...|NEGATIVE|0.99973005|\n",
- "| I hoped for this show to be somewhat realistic|POSITIVE| 0.8426496|\n",
- "| All I have to say is one word|NEGATIVE| 0.9784491|\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy scr...|NEGATIVE| 0.99977|\n",
- "|This critique tells the story of 4 little friends who went to watch Angels an...|POSITIVE| 0.9942334|\n",
- "| This review contains a partial spoiler|NEGATIVE| 0.996191|\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before bl...|NEGATIVE| 0.9984042|\n",
+ "| There were two things I hated about WASTED : The directing and the script |NEGATIVE| 0.9979019|\n",
"| I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.8392794|\n",
- "| If you like bad movies (and you must to watch this one) here's a good one|POSITIVE|0.99366415|\n",
- "|This is really bad, the characters were bland, the story was boring, and ther...|NEGATIVE|0.99953806|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an ac...|NEGATIVE|0.99726933|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw the...|POSITIVE|0.98212516|\n",
+ "| This movie has been done before|NEGATIVE|0.94194806|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for eac...|NEGATIVE|0.99678314|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 1...|NEGATIVE| 0.9985846|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|NEGATIVE|0.99926823|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get invo...|NEGATIVE| 0.9995671|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\n",
+ "| My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hyste...|NEGATIVE| 0.9953224|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was t...|NEGATIVE| 0.9997607|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|NEGATIVE|0.99971956|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|NEGATIVE|0.99651587|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can p...|NEGATIVE|0.99860746|\n",
+ "| This show is like watching someone who is in training to someday host a show|NEGATIVE| 0.970153|\n",
+ "| Sigh|NEGATIVE|0.99231356|\n",
"+--------------------------------------------------------------------------------+--------+----------+\n",
"only showing top 20 rows\n",
"\n"
@@ -455,289 +622,419 @@
},
{
"cell_type": "markdown",
- "id": "eb826fde-99d9-43fe-8ddc-f5acbe76b4e9",
+ "id": "8ba1a6ce",
"metadata": {},
"source": [
- "### Using Triton Inference Server\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
"\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
+ "\n",
+ ""
]
},
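+ {
+ "cell_type": "markdown",
+ "id": "a0b1c2d3",
+ "metadata": {},
+ "source": [
+ "Before walking through the distributed version, the sketch below shows the same server/client round trip in a single process, using the same PyTriton APIs as the cells that follow. It is illustrative only: the `Echo` model, its identity inference function, and the default HTTP port 8000 are assumptions for this sketch, not part of the notebook's pipeline."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b1c2d3e4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative single-process sketch; 'Echo' is a placeholder model.\n",
+ "import numpy as np\n",
+ "from pytriton.decorators import batch\n",
+ "from pytriton.model_config import ModelConfig, Tensor\n",
+ "from pytriton.triton import Triton\n",
+ "from pytriton.client import ModelClient\n",
+ "\n",
+ "@batch\n",
+ "def _echo(**inputs):\n",
+ "    # Stand-in for a real model: return the input bytes unchanged.\n",
+ "    return {\"outputs\": inputs[\"text\"]}\n",
+ "\n",
+ "triton = Triton()  # default config serves HTTP on port 8000 (an assumption here)\n",
+ "triton.bind(\n",
+ "    model_name=\"Echo\",\n",
+ "    infer_func=_echo,\n",
+ "    inputs=[Tensor(name=\"text\", dtype=object, shape=(-1,))],\n",
+ "    outputs=[Tensor(name=\"outputs\", dtype=object, shape=(-1,))],\n",
+ "    config=ModelConfig(max_batch_size=8),\n",
+ ")\n",
+ "triton.run()  # non-blocking, unlike the blocking serve() used on executors below\n",
+ "\n",
+ "with ModelClient(\"http://localhost:8000\", \"Echo\") as client:\n",
+ "    result = client.infer_batch(np.array([[b\"hello\"]], dtype=np.bytes_))\n",
+ "print(result[\"outputs\"])\n",
+ "\n",
+ "triton.stop()"
+ ]
+ },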
{
- "cell_type": "markdown",
- "id": "10368010-f94d-4167-91a1-2cf9ed91a2c9",
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815",
"metadata": {},
+ "outputs": [],
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n huggingface-torch -c conda-forge python=3.10.0\n",
- "conda activate huggingface-torch\n",
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import torch\n",
+ " from transformers import pipeline\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " print(f\"SERVER: Initializing pipeline on worker {TaskContext.get().partitionId()}.\")\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " pipe = pipeline(\"sentiment-analysis\", device=device)\n",
+ " print(f\"SERVER: Using {device} device.\")\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"]).tolist()\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " return {\n",
+ " \"outputs\": np.array([[json.dumps(o)] for o in pipe(decoded_sentences)])\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"SentimentAnalysis\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"outputs\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
"\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
"\n",
- "conda-pack # huggingface-torch.tar.gz\n",
- "```"
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7c5f4f2d",
+ "metadata": {},
+ "source": [
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
]
},
{
"cell_type": "code",
- "execution_count": 21,
- "id": "4d4be844-4b8c-47df-bd09-0c280c7ff16b",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 28,
+ "id": "1c4f2412",
+ "metadata": {},
"outputs": [],
"source": [
- "import numpy as np\n",
- "import pandas as pd\n",
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct, pandas_udf\n",
- "from pyspark.sql.types import FloatType, StringType, StructField, StructType"
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5463c517",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
- "execution_count": 22,
- "id": "7e53df9f-43cb-4c38-b8ac-dc2cbad99815",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 29,
+ "id": "a4757163",
+ "metadata": {},
"outputs": [],
"source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/hf_pipeline_torch models\n",
- "\n",
- "# add custom execution environment\n",
- "cp huggingface-torch.tar.gz models"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ad13db78",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
"cell_type": "markdown",
- "id": "db4a5b06-126a-4bc4-baae-a45ea30832a7",
- "metadata": {
- "tags": []
- },
+ "id": "5febf6e8",
+ "metadata": {},
"source": [
- "#### Start Triton Server on each executor"
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 23,
- "id": "144acb8e-4c08-40fc-a9ed-f721c409ee68",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 31,
+ "id": "e786e29c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "79a4e9d7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using ports [7000, 7001, 7002]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"SentimentAnalysis\"\n",
+ "\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "7a1a4c4c",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 28:> (0 + 1) / 1]\r"
]
},
{
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 23,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2571652\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "huggingface_cache_dir = \"{}/.cache/huggingface\".format(os.path.expanduser('~'))\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " environment=[\n",
- " \"TRANSFORMERS_CACHE=/cache\"\n",
- " ],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"256M\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " huggingface_cache_dir: {\"bind\": \"/cache\", \"mode\": \"rw\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " \n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " \n",
- " elapsed = 0\n",
- " timeout = 120\n",
- " ready = False\n",
- " while not ready and elapsed < timeout:\n",
- " try:\n",
- " time.sleep(5)\n",
- " elapsed += 5\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " pass\n",
- "\n",
- " return [True]\n",
- "\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
]
},
{
"cell_type": "markdown",
- "id": "c24d77ab-60d3-45eb-a9c2-dc811eca0af4",
+ "id": "f5ae0b8e",
"metadata": {},
"source": [
- "#### Run inference"
+ "#### Define client function"
]
},
{
"cell_type": "code",
- "execution_count": 24,
- "id": "d53fb283-bf9e-4571-8c68-b75a41f1f067",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 34,
+ "id": "f6899d96",
+ "metadata": {},
"outputs": [],
"source": [
- "# only use first sentence of IMDB reviews\n",
- "@pandas_udf(\"string\")\n",
- "def first_sentence(text: pd.Series) -> pd.Series:\n",
- " return pd.Series([s.split(\".\")[0] for s in text])\n",
- "\n",
- "df = spark.read.parquet(\"imdb_test\").withColumn(\"sentence\", first_sentence(col(\"lines\"))).select(\"sentence\").limit(1000)"
+ "url = f\"http://localhost:{ports[0]}\""
]
},
{
"cell_type": "code",
- "execution_count": 25,
- "id": "29b0cc0d-c480-4e4a-bd41-207dc314cba5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 35,
+ "id": "14760940",
+ "metadata": {},
"outputs": [],
"source": [
- "def triton_fn(triton_uri, model_name):\n",
+ "def triton_fn(url, model_name):\n",
" import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
+ "\n",
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " result_data = np.squeeze(result_data[\"outputs\"], -1)\n",
+ " return [json.loads(o) for o in result_data]\n",
" \n",
- " return predict"
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a741e23a",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 36,
+ "id": "ccc884a4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "c426fdbe",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 18:29:56 WARN CacheManager: Asked to cache already cached data.\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7da06df4",
+ "metadata": {},
+ "source": [
+ "#### Run Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
"id": "3930cfcd-3284-4c6a-a9b5-36b8053fe899",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "from functools import partial\n",
- "\n",
- "classify = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"hf_pipeline_torch\"),\n",
+ "classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
" return_type=StructType([\n",
" StructField(\"label\", StringType(), True),\n",
" StructField(\"score\", FloatType(), True)\n",
" ]),\n",
" input_tensor_shapes=[[1]],\n",
- " batch_size=100)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 39,
"id": "8eecbf23-4e9e-4d4c-8645-98209b25db2c",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 20:> (0 + 1) / 1]\r"
+ "[Stage 32:===========================================> (6 + 2) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.89 ms, sys: 5.41 ms, total: 11.3 ms\n",
- "Wall time: 1.98 s\n"
+ "CPU times: user 16.7 ms, sys: 3.77 ms, total: 20.4 ms\n",
+ "Wall time: 2.58 s\n"
]
},
{
@@ -752,157 +1049,116 @@
"%%time\n",
"# first pass caches model/fn\n",
"# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(struct(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(struct(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 28,
+ "execution_count": 40,
"id": "566ba28c-0ca4-4479-a24a-c8a362228b89",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 21:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.87 ms, sys: 2.39 ms, total: 8.26 ms\n",
- "Wall time: 1.87 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 2.77 ms, sys: 0 ns, total: 2.77 ms\n",
+ "Wall time: 462 ms\n"
]
}
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(\"sentence\")).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(\"input\")).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 29,
+ "execution_count": 41,
"id": "44c7e776-08da-484a-ba07-9d6add1a0f15",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 22:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 5.24 ms, sys: 1.13 ms, total: 6.37 ms\n",
- "Wall time: 1.86 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 2.51 ms, sys: 2.71 ms, total: 5.22 ms\n",
+ "Wall time: 461 ms\n"
]
}
],
"source": [
"%%time\n",
- "# note: expanding the \"struct\" return_type to top-level columns\n",
- "preds = df.withColumn(\"preds\", classify(col(\"sentence\"))).select(\"sentence\", \"preds.*\")\n",
+ "preds = df.withColumn(\"preds\", classify(col(\"input\"))).select(\"input\", \"preds.*\")\n",
"results = preds.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 42,
"id": "f61d79f8-661e-4d9e-a3aa-c0754b854603",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
- "|sentence |label |score |\n",
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
- "| |POSITIVE|0.7481212 |\n",
- "|Hard up, No proper jobs going down at the pit, why not rent your kids! DIY pimp story without the gratuitous sex scenes, either hard core or soft core, therefore reads like a public information film from the fifties, give this a wide miss, use a barge pole if you can|NEGATIVE|0.99967253|\n",
- "|I watched this movie to see the direction one of the most promising young talents in movies was going |POSITIVE|0.9994943 |\n",
- "|This movie makes you wish imdb would let you vote a zero |NEGATIVE|0.9981305 |\n",
- "|I never want to see this movie again!
Not only is it dreadfully bad, but I can't stand seeing my hero Stan Laurel looking so old and sick |NEGATIVE|0.9988337 |\n",
- "|(As a note, I'd like to say that I saw this movie at my annual church camp, where the entire youth group laughed at it |POSITIVE|0.9901974 |\n",
- "|Don't get me wrong, I love the TV series of League Of Gentlemen |POSITIVE|0.9998311 |\n",
- "|Did you ever think, like after watching a horror movie with a group of friends: \"Wow, this is so cool! We have got to make a splatter horror movie ourselves some day soon |POSITIVE|0.9992779 |\n",
- "|Awful, awful, awful |NEGATIVE|0.9997433 |\n",
- "|This movie seems a little clunky around the edges, like not quite enough zaniness was thrown it when it should have been |NEGATIVE|0.99965274|\n",
- "|I rented this movie hoping that it would provide some good entertainment and some cool poker knowledge or stories |NEGATIVE|0.99642426|\n",
- "|Well, where to start describing this celluloid debacle? You already know the big fat NADA passing as a plot, so let's jut point out that this is so PC it's offensive |NEGATIVE|0.99973005|\n",
- "|I hoped for this show to be somewhat realistic |POSITIVE|0.8426496 |\n",
- "|All I have to say is one word |NEGATIVE|0.9784491 |\n",
- "|Honestly awful film, bad editing, awful lighting, dire dialog and scrappy screenplay |NEGATIVE|0.99977 |\n",
- "|This critique tells the story of 4 little friends who went to watch Angels and Demons the movie on the first night it came out, even though it was a school night, because \"Angels and Demons is worth it |POSITIVE|0.9942334 |\n",
- "|This review contains a partial spoiler |NEGATIVE|0.996191 |\n",
- "|I'm rather surprised that anybody found this film touching or moving |POSITIVE|0.8392794 |\n",
- "|If you like bad movies (and you must to watch this one) here's a good one |POSITIVE|0.99366415|\n",
- "|This is really bad, the characters were bland, the story was boring, and there is no sex scene |NEGATIVE|0.99953806|\n",
- "+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+----------+\n",
+ "+----------------------------------------------------------------------+--------+----------+\n",
+ "| input| label| score|\n",
+ "+----------------------------------------------------------------------+--------+----------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from...|NEGATIVE| 0.9984042|\n",
+ "|There were two things I hated about WASTED : The directing and the ...|NEGATIVE| 0.9979019|\n",
+ "| I'm rather surprised that anybody found this film touching or moving|POSITIVE| 0.8392794|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Tra...|NEGATIVE|0.99726933|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event...|POSITIVE|0.98212516|\n",
+ "| This movie has been done before|NEGATIVE|0.94194806|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comme...|NEGATIVE|0.99678314|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch t...|NEGATIVE| 0.9985846|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans ...|NEGATIVE|0.99926823|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actor...|NEGATIVE| 0.9995671|\n",
+ "|There is not one character on this sitcom with any redeeming qualities|NEGATIVE|0.99856514|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|POSITIVE| 0.9945687|\n",
+ "| My wife rented this movie and then conveniently never got to see it|NEGATIVE| 0.9984137|\n",
+ "|This is one of those star-filled over-the-top comedies that could a...|NEGATIVE| 0.9953224|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Cha...|NEGATIVE| 0.9997607|\n",
+ "|you will likely be sorely disappointed by this sequel that's not a ...|NEGATIVE|0.99971956|\n",
+ "|If I was British, I would be embarrassed by this portrayal of incom...|NEGATIVE|0.99651587|\n",
+ "|One of those movies in which there are no big twists whatsoever and...|NEGATIVE|0.99860746|\n",
+ "|This show is like watching someone who is in training to someday ho...|NEGATIVE| 0.970153|\n",
+ "| Sigh|NEGATIVE|0.99231356|\n",
+ "+----------------------------------------------------------------------+--------+----------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
- "preds.show(truncate=False)"
+ "preds.show(truncate=70)"
]
},
{
"cell_type": "markdown",
- "id": "e197c146-1794-47f0-bcd9-7e8d8ab8625f",
- "metadata": {
- "tags": []
- },
+ "id": "2248858c",
+ "metadata": {},
"source": [
- "#### Stop Triton Server on each executor"
+ "#### Shut down server on each executor"
]
},
{
"cell_type": "code",
- "execution_count": 31,
- "id": "425d3b28-7705-45ba-8a18-ad34fc895219",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 43,
+ "id": "e3a4e51f",
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -916,31 +1172,39 @@
"[True]"
]
},
- "execution_count": 31,
+ "execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 32,
+ "execution_count": 44,
"id": "9f19643c-4ee4-44f2-b762-2078c0c8eba9",
"metadata": {},
"outputs": [],
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb
index deac314d0..2cd4e056a 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/huggingface/sentence_transformers_torch.ipynb
@@ -5,16 +5,19 @@
"id": "777fc40d",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# PySpark Huggingface Inferencing\n",
- "## Sentence Transformers with PyTorch\n",
+ "### Sentence Transformers with PyTorch\n",
"\n",
+ "In this notebook, we demonstrate distributed inference with the Huggingface SentenceTransformer library for sentence embedding. \n",
"From: https://huggingface.co/sentence-transformers"
]
},
{
"cell_type": "code",
"execution_count": 1,
- "id": "731faab7-a700-46f8-bba5-1c8764e5eacb",
+ "id": "c5f0d0a8",
"metadata": {},
"outputs": [
{
@@ -22,27 +25,46 @@
"output_type": "stream",
"text": [
"/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:13: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
- " from tqdm.autonotebook import tqdm, trange\n",
+ " from tqdm.autonotebook import tqdm, trange\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "\n",
+ "# Manually enable Huggingface tokenizer parallelism to avoid disabling with PySpark parallelism.\n",
+ "# See (https://github.com/huggingface/transformers/issues/5486) for more info. \n",
+ "import os\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "731faab7-a700-46f8-bba5-1c8764e5eacb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
"/home/rishic/anaconda3/envs/spark-dl-torch/lib/python3.11/site-packages/transformers/tokenization_utils_base.py:1617: FutureWarning: `clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be deprecated in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884\n",
" warnings.warn(\n"
]
}
],
"source": [
- "from sentence_transformers import SentenceTransformer\n",
- "model = SentenceTransformer('paraphrase-MiniLM-L6-v2')\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
"\n",
- "#Sentences we want to encode. Example:\n",
"sentence = ['This framework generates embeddings for each input sentence']\n",
- "\n",
- "\n",
- "#Sentences are encoded by calling model.encode()\n",
"embedding = model.encode(sentence)"
]
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"id": "96eea5ca-3cf7-46e3-b40c-598538112d24",
"metadata": {},
"outputs": [
@@ -69,33 +91,68 @@
"## PySpark"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "dbda3e66-005a-4ad0-8017-c1cc7cbf0058",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyspark.sql.types import *\n",
+ "from pyspark import SparkConf\n",
+ "from pyspark.sql import SparkSession\n",
+ "from pyspark.sql.functions import pandas_udf, col, struct\n",
+ "from pyspark.ml.functions import predict_batch_udf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "b525c5c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "import pandas as pd\n",
+ "import datasets\n",
+ "from datasets import load_dataset\n",
+ "datasets.disable_progress_bars()"
+ ]
+ },
{
"cell_type": "markdown",
- "id": "e8938317-e31e-4e8d-b2d8-f92c1b5a300c",
+ "id": "58e7c1bc",
"metadata": {},
"source": [
- "## Inference using Spark DL API\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
+ "Check the cluster environment to handle any platform-specific Spark configurations."
]
},
{
"cell_type": "code",
- "execution_count": 3,
- "id": "dbda3e66-005a-4ad0-8017-c1cc7cbf0058",
+ "execution_count": 6,
+ "id": "5a013217",
"metadata": {},
"outputs": [],
"source": [
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType\n",
- "from pyspark.sql import SparkSession\n",
- "from pyspark import SparkConf\n",
- "from datasets import load_dataset"
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ad3c003d",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 7,
"id": "23ec67ba",
"metadata": {},
"outputs": [
@@ -103,136 +160,258 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
- "To disable this warning, you can either:\n",
- "\t- Avoid using `tokenizers` before the fork if possible\n",
- "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n",
- "24/10/08 00:19:28 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
- "24/10/08 00:19:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "25/01/06 18:43:02 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/06 18:43:02 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
"Setting default log level to \"WARN\".\n",
"To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
- "24/10/08 00:19:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ "25/01/06 18:43:03 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
- "import os\n",
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
- "\n",
"conf = SparkConf()\n",
+ "\n",
"if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
"spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
"sc = spark.sparkContext"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "4cfd1394",
+ "metadata": {},
+ "source": [
+ "Load the IMBD Movie Reviews dataset from Huggingface."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 8,
"id": "9bc1edb5",
"metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"imdb\", split=\"test\")\n",
+ "dataset = dataset.to_pandas().drop(columns=\"label\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "59c71bff",
+ "metadata": {},
+ "source": [
+ "#### Create PySpark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "836e5f84-12c6-4c95-838e-53de7e46a20b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "StructType([StructField('text', StringType(), True)])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df = spark.createDataFrame(dataset).repartition(8)\n",
+ "df.schema"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "36703d23-37a3-40df-b09a-c68206d285b6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "25000"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "1f122ae3",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 18:43:10 WARN TaskSetManager: Stage 6 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[Row(text=\"Anyone remember the first CKY, CKY2K etc..? Back when it was about making crazy cool stuff, rather than watching Bam Margera act like a douchebag, spoiled 5 year old, super/rock-star wannabe.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "# load IMDB reviews (test) dataset and write to parquet\n",
- "data = load_dataset(\"imdb\", split=\"test\")\n",
- "\n",
- "lines = []\n",
- "for example in data:\n",
- " lines.append([example[\"text\"].split(\".\")[0]])\n",
- "\n",
- "len(lines)\n",
+ "df.take(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "14fd59fb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 18:43:10 WARN TaskSetManager: Stage 9 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
"\n",
- "df = spark.createDataFrame(lines, ['lines']).repartition(10)\n",
- "df.schema\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6bb083ec",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame\n",
"\n",
- "df.write.mode(\"overwrite\").parquet(\"imdb_test\")"
+ "Define our preprocess function. We'll take the first sentence from each sample as our input for translation."
]
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "836e5f84-12c6-4c95-838e-53de7e46a20b",
+ "execution_count": 13,
+ "id": "2510bdd1",
"metadata": {},
"outputs": [],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100).cache()"
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
]
},
{
"cell_type": "code",
- "execution_count": 7,
- "id": "36703d23-37a3-40df-b09a-c68206d285b6",
+ "execution_count": 14,
+ "id": "5bb28548",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| lines|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| This is so overly clichéd you'll want to switch it off after the first 45 minutes|\n",
- "| I was very disappointed by this movie|\n",
- "| I think vampire movies (usually) are wicked|\n",
- "| Though not a complete waste of time, 'Eighteen' really wasn't all sweet as it pretended to be|\n",
- "|This film did well at the box office, and the producers of this mess thought the stars had such good chemistry in thi...|\n",
- "| Peter Crawford discovers a comet on a collision course with the moon|\n",
- "|This tale of the upper-classes getting their come-uppance and wallowing in their high-class misery is like a contempo...|\n",
- "|Words almost fail me to describe how terrible this Irish vanity project (funded by Canadian taxpayers - both federal ...|\n",
- "| This was the most uninteresting horror flick I have seen to date|\n",
- "| Heart of Darkness was terrible|\n",
- "| I saw this movie when it was first released in Pittsburgh Pa|\n",
- "|It was funny because the whole thing was so unrealistic, I mean, come on, like a pop star would just show up at a pub...|\n",
- "|Watching this movie, you just have to ask: What were they thinking? There are so many noticeably bad parts of this mo...|\n",
- "| In a sense, this movie did not even compare to the novel|\n",
- "| Poor Jane Austen ought to be glad she's not around to see this dreadful wreck of an adaptation|\n",
- "| I gave this movie a four-star rating for a few reasons|\n",
- "| It seems that Dee Snyder ran out of ideas halfway through the script|\n",
- "| Now, let me see if I have this correct, a lunatic serial killer is going around murdering estate agents|\n",
- "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
- "|First of all, I would like to say that I am a fan of all of the actors that appear in this film and at the time that ...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| input|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind of sludge comes from before blathering on about it...|\n",
+ "| There were two things I hated about WASTED : The directing and the script |\n",
+ "| I'm rather surprised that anybody found this film touching or moving|\n",
+ "|Cultural Vandalism Is the new Hallmark production of Gulliver's Travels an act of cultural vandal...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 year old, and the event I saw then was pretty differe...|\n",
+ "| This movie has been done before|\n",
+ "|[ as a new resolution for this year 2005, i decide to write a comment for each movie I saw in the...|\n",
+ "|This movie is over hyped!! I am sad to say that I manage to watch the first 15 minutes of this mo...|\n",
+ "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did such talented actors get involved in such mindles...|\n",
+ "| There is not one character on this sitcom with any redeeming qualities|\n",
+ "| Tommy Lee Jones was the best Woodroe and no one can play Woodroe F|\n",
+ "| My wife rented this movie and then conveniently never got to see it|\n",
+ "|This is one of those star-filled over-the-top comedies that could a) be hysterical, or b) wish th...|\n",
+ "|This excruciatingly boring and unfunny movie made me think that Chaplin was the real Hitler, as o...|\n",
+ "| you will likely be sorely disappointed by this sequel that's not a sequel|\n",
+ "| If I was British, I would be embarrassed by this portrayal of incompetence|\n",
+ "|One of those movies in which there are no big twists whatsoever and you can predict pretty much w...|\n",
+ "| This show is like watching someone who is in training to someday host a show|\n",
+ "| Sigh|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
}
],
"source": [
- "df.show(truncate=120)"
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()\n",
+ "df.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "014eae88",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 15,
"id": "f780c026-0f3f-4aea-8b61-5b3dbae83fb7",
"metadata": {},
"outputs": [],
"source": [
"def predict_batch_fn():\n",
+ " import torch\n",
" from sentence_transformers import SentenceTransformer\n",
- " model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\")\n",
+ "\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
" def predict(inputs):\n",
" return model.encode(inputs.tolist())\n",
" return predict"
@@ -240,19 +419,19 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 16,
"id": "f5c88ddc-ca19-4430-8b0e-b9fae143b237",
"metadata": {},
"outputs": [],
"source": [
"encode = predict_batch_udf(predict_batch_fn,\n",
" return_type=ArrayType(FloatType()),\n",
- " batch_size=10)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 17,
"id": "85344c22-4a4d-4cb0-8771-5836ae2794db",
"metadata": {},
"outputs": [
@@ -260,15 +439,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 9:> (0 + 1) / 1]\r"
+ "[Stage 18:=============================> (4 + 4) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.34 ms, sys: 4.15 ms, total: 8.48 ms\n",
- "Wall time: 2.58 s\n"
+ "CPU times: user 10.9 ms, sys: 5.42 ms, total: 16.3 ms\n",
+ "Wall time: 3.58 s\n"
]
},
{
@@ -282,530 +461,606 @@
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "embeddings = df.withColumn(\"encoding\", encode(struct(\"lines\")))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 18,
"id": "c23bb885-6ab0-4471-943d-4c10414100fa",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 11:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.76 ms, sys: 4.89 ms, total: 6.65 ms\n",
- "Wall time: 2.47 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 1.37 ms, sys: 5.71 ms, total: 7.08 ms\n",
+ "Wall time: 162 ms\n"
]
}
],
"source": [
"%%time\n",
- "embeddings = df.withColumn(\"encoding\", encode(\"lines\"))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 19,
"id": "93bc6da3-d853-4233-b805-cb4a46f4f9b9",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 13:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.55 ms, sys: 6.05 ms, total: 7.6 ms\n",
- "Wall time: 2.46 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 4.15 ms, sys: 4.07 ms, total: 8.21 ms\n",
+ "Wall time: 202 ms\n"
]
}
],
"source": [
"%%time\n",
- "embeddings = df.withColumn(\"encoding\", encode(col(\"lines\")))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 20,
"id": "2073616f-7151-4760-92f2-441dd0bfe9fe",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 15:> (0 + 1) / 1]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| lines| encoding|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|This is so overly clichéd you'll want to switch it off af...|[-0.06755405, -0.13365394, 0.36675274, -0.2772311, -0.085...|\n",
- "| I was very disappointed by this movie|[-0.05903806, 0.16684641, 0.16768408, 0.10940918, 0.18100...|\n",
- "| I think vampire movies (usually) are wicked|[0.025601083, -0.5308639, -0.319133, -0.013351389, -0.338...|\n",
- "|Though not a complete waste of time, 'Eighteen' really wa...|[0.20991832, 0.5228605, 0.44517252, -0.031682555, -0.4117...|\n",
- "|This film did well at the box office, and the producers o...|[0.18097948, -0.03622232, -0.34149718, 0.061557338, -0.06...|\n",
- "|Peter Crawford discovers a comet on a collision course wi...|[-0.27548054, 0.196654, -0.24626413, -0.39380816, -0.5501...|\n",
- "|This tale of the upper-classes getting their come-uppance...|[0.24201547, 0.011018356, -0.080340266, 0.31388673, -0.28...|\n",
- "|Words almost fail me to describe how terrible this Irish ...|[0.055901285, -0.14539501, -0.14005454, -0.038912475, 0.4...|\n",
- "|This was the most uninteresting horror flick I have seen ...|[0.27159664, -0.012541974, -0.31898177, 0.058205508, 0.56...|\n",
- "| Heart of Darkness was terrible|[0.1593065, 0.36501122, 0.10715093, 0.76344764, 0.2555183...|\n",
- "|I saw this movie when it was first released in Pittsburgh Pa|[-0.34647614, 0.115615666, -0.18874267, 0.36590436, -0.06...|\n",
- "|It was funny because the whole thing was so unrealistic, ...|[0.09473594, -0.43785918, 0.14436111, 0.0045353747, -0.08...|\n",
- "|Watching this movie, you just have to ask: What were they...|[0.43020695, -0.09714467, 0.1356213, 0.23126744, -0.03908...|\n",
- "| In a sense, this movie did not even compare to the novel|[0.2838324, -0.018966805, -0.37275136, 0.27034461, 0.2017...|\n",
- "|Poor Jane Austen ought to be glad she's not around to see...|[0.27462235, -0.32494685, 0.48243234, 0.07208571, 0.22470...|\n",
- "| I gave this movie a four-star rating for a few reasons|[0.31143323, -0.09470663, -0.10863629, 0.077851094, -0.15...|\n",
- "|It seems that Dee Snyder ran out of ideas halfway through...|[0.44354546, -0.08122106, -0.15206784, -0.29244298, 0.559...|\n",
- "|Now, let me see if I have this correct, a lunatic serial ...|[0.39831734, 0.15871558, -0.35366735, -0.11643518, -0.137...|\n",
- "|Tommy Lee Jones was the best Woodroe and no one can play ...|[-0.20960264, -0.15760101, -0.30596393, -0.51817703, -0.0...|\n",
- "|First of all, I would like to say that I am a fan of all ...|[0.25831866, -0.26871824, 0.026099348, -0.3459879, -0.180...|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| embedding|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n",
+ "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n",
+ "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n",
+ "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n",
+ "| This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n",
+ "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n",
+ "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n",
+ "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
How did...|[-0.048036292, 0.050720096, -0.04668727, -0.316...|\n",
+ "|There is not one character on this sitcom with ...|[0.13720773, -0.5963504, 0.30331734, -0.3830607...|\n",
+ "|Tommy Lee Jones was the best Woodroe and no one...|[-0.20960267, -0.15760122, -0.30596405, -0.5181...|\n",
+ "|My wife rented this movie and then conveniently...|[0.46534792, -0.40655977, 0.054217298, -0.03414...|\n",
+ "|This is one of those star-filled over-the-top c...|[0.14433198, -0.016140658, 0.3775344, 0.0659043...|\n",
+ "|This excruciatingly boring and unfunny movie ma...|[0.056464806, 0.01144963, -0.51797307, 0.089813...|\n",
+ "|you will likely be sorely disappointed by this ...|[-0.44146675, -0.17866582, 0.49889183, -0.26819...|\n",
+ "|If I was British, I would be embarrassed by thi...|[0.1191261, -0.15379854, 0.17487673, -0.5123498...|\n",
+ "|One of those movies in which there are no big t...|[-0.016174048, -0.5558219, -0.024818476, 0.1543...|\n",
+ "|This show is like watching someone who is in tr...|[0.033776704, -0.6682203, 0.30547586, -0.581407...|\n",
+ "| Sigh|[-0.119870394, 0.40893683, 0.4174831, -0.010004...|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
"only showing top 20 rows\n",
"\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
}
],
"source": [
- "embeddings.show(truncate=60)"
+ "embeddings.show(truncate=50)"
]
},
{
"cell_type": "markdown",
- "id": "b730f5a3-f7eb-42aa-8869-881ecd0f5542",
+ "id": "0c9c6535",
"metadata": {},
"source": [
- "### Using Triton Inference Server\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
"\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
+ "\n",
+ ""
]
},
{
- "cell_type": "markdown",
- "id": "5f502a20",
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "772e337e-1098-4c7b-ba81-8cb221a518e2",
"metadata": {},
+ "outputs": [],
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) with the compatible versions of Python/Numpy for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n huggingface-torch -c conda-forge python=3.10.0\n",
- "conda activate huggingface-torch\n",
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "69d0c93a-bb0b-46c5-9d28-7b08a2e70964",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import torch\n",
+ " from sentence_transformers import SentenceTransformer\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
"\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 conda-pack sentencepiece sentence_transformers transformers\n",
+ " print(f\"SERVER: Initializing sentence transformer on worker {TaskContext.get().partitionId()}.\")\n",
+ " device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ " model = SentenceTransformer(\"paraphrase-MiniLM-L6-v2\", device=device)\n",
+ " print(f\"SERVER: Using {device} device.\")\n",
"\n",
- "conda-pack # huggingface-torch.tar.gz\n",
- "```"
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = np.squeeze(inputs[\"text\"])\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}\")\n",
+ " decoded_sentences = [s.decode(\"utf-8\") for s in sentences]\n",
+ " embeddings = model.encode(decoded_sentences)\n",
+ " return {\n",
+ " \"embeddings\": embeddings,\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"SentenceTransformer\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=object, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"embeddings\", dtype=np.float32, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=64,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "79532110",
+ "metadata": {},
+ "source": [
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "772e337e-1098-4c7b-ba81-8cb221a518e2",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 23,
+ "id": "b20fd862",
+ "metadata": {},
"outputs": [],
"source": [
- "import numpy as np\n",
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType"
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bef23176",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
- "execution_count": 15,
- "id": "69d0c93a-bb0b-46c5-9d28-7b08a2e70964",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 24,
+ "id": "b992802e",
+ "metadata": {},
"outputs": [],
"source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/hf_transformer_torch models\n",
- "\n",
- "# add custom execution environment\n",
- "cp huggingface-torch.tar.gz models"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "69015ae1",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
"cell_type": "markdown",
- "id": "dd4d7d4b-1a0b-4c5f-bc93-be2a039b6ea0",
- "metadata": {
- "tags": []
- },
+ "id": "32d5e8e9",
+ "metadata": {},
"source": [
- "#### Start Triton Server on each executor"
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 16,
- "id": "1654cdc1-4f9a-4fd5-b7ac-6ca4215bde5d",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 26,
+ "id": "648c0b50",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "012b2d60",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using ports [7000, 7001, 7002]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"SentenceTransformer\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "ea38ac6b",
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 28:> (0 + 1) / 1]\r"
]
},
{
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2583427\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "huggingface_cache_dir = \"{}/.cache/huggingface\".format(os.path.expanduser('~'))\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " environment=[\n",
- " \"TRANSFORMERS_CACHE=/cache\"\n",
- " ],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"512M\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " huggingface_cache_dir: {\"bind\": \"/cache\", \"mode\": \"rw\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- "\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1fd19fae",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "00d82bfe",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "807dbc45",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
"\n",
- " return [True]\n",
+ " print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
"\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " flattened = np.squeeze(inputs).tolist()\n",
+ " # Encode batch\n",
+ " encoded_batch = [[text.encode(\"utf-8\")] for text in flattened]\n",
+ " encoded_batch_np = np.array(encoded_batch, dtype=np.bytes_)\n",
+ " # Run inference\n",
+ " result_data = client.infer_batch(encoded_batch_np)\n",
+ " return result_data[\"embeddings\"]\n",
+ " \n",
+ " return infer_batch"
]
},
{
"cell_type": "markdown",
- "id": "ee34de5f-89f8-455e-b45e-a557a4ab0f05",
+ "id": "af174106",
"metadata": {},
"source": [
- "#### Run inference"
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 31,
"id": "2969d502-e97b-49d6-bf80-7d177ae867cf",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "from functools import partial\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType"
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
]
},
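The UDF above keeps only the first sentence of each review, which is a cheap way to bound input length for the encoder. The same logic, checked locally with plain pandas:

```python
# The preprocessing logic from the UDF, run on a toy Series.
import pandas as pd

sample = pd.Series(["Great movie. Would watch again.", "Terrible plot"])
print([s.split(".")[0] for s in sample])  # ['Great movie', 'Terrible plot']
```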
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 32,
"id": "c8f1e6d6-6519-49e7-8465-4419547633b8",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/08 00:20:24 WARN CacheManager: Asked to cache already cached data.\n"
+ "25/01/06 18:43:20 WARN CacheManager: Asked to cache already cached data.\n"
]
}
],
"source": [
- "# only use first N examples, since this is slow\n",
- "df = spark.read.parquet(\"imdb_test\").limit(100).cache()"
+ "df = spark.read.parquet(data_path).limit(256).repartition(8)\n",
+ "df = df.select(preprocess(col(\"text\")).alias(\"input\")).cache()"
]
},
{
- "cell_type": "code",
- "execution_count": 19,
- "id": "29b0cc0d-c480-4e4a-bd41-207dc314cba5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
+ "cell_type": "markdown",
+ "id": "cf0ee731",
+ "metadata": {},
"source": [
- "def triton_fn(triton_uri, model_name):\n",
- " import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
+ "#### Run inference"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 33,
"id": "9c712b8f-6eb4-4fb8-9f0a-04feef847fea",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "encode = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"hf_transformer_torch\"),\n",
+ "encode = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
" return_type=ArrayType(FloatType()),\n",
" input_tensor_shapes=[[1]],\n",
- " batch_size=100)"
+ " batch_size=32)"
]
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 34,
"id": "934c1a1f-b126-45b0-9c15-265236820ad3",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 32:=======> (1 + 7) / 8]\r"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 4.65 ms, sys: 2.85 ms, total: 7.49 ms\n",
- "Wall time: 480 ms\n"
+ "CPU times: user 7.72 ms, sys: 5.27 ms, total: 13 ms\n",
+ "Wall time: 2.32 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
]
}
],
"source": [
"%%time\n",
"# first pass caches model/fn\n",
- "embeddings = df.withColumn(\"encoding\", encode(struct(\"lines\")))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(struct(\"input\")))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 35,
"id": "f84cd3f6-b6a8-4142-859a-91f3c183457b",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.45 ms, sys: 1.1 ms, total: 2.56 ms\n",
- "Wall time: 384 ms\n"
+ "CPU times: user 7.92 ms, sys: 0 ns, total: 7.92 ms\n",
+ "Wall time: 171 ms\n"
]
}
],
"source": [
"%%time\n",
- "embeddings = df.withColumn(\"encoding\", encode(\"lines\"))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(\"input\"))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 36,
"id": "921a4c01-e296-4406-be90-86f20c8c582d",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 1.63 ms, sys: 1.28 ms, total: 2.91 ms\n",
- "Wall time: 416 ms\n"
+ "CPU times: user 6.04 ms, sys: 554 μs, total: 6.59 ms\n",
+ "Wall time: 225 ms\n"
]
}
],
"source": [
"%%time\n",
- "embeddings = df.withColumn(\"encoding\", encode(col(\"lines\")))\n",
+ "embeddings = df.withColumn(\"embedding\", encode(col(\"input\")))\n",
"results = embeddings.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 37,
"id": "9f67584e-9c4e-474f-b6ea-7811b14d116e",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "| lines| encoding|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
- "|This is so overly clichéd you'll want to switch it off af...|[-0.06755393, -0.1336537, 0.366753, -0.2772312, -0.085145...|\n",
- "| I was very disappointed by this movie|[-0.059038587, 0.1668467, 0.16768396, 0.10940957, 0.18100...|\n",
- "| I think vampire movies (usually) are wicked|[0.025601566, -0.5308643, -0.31913283, -0.013350786, -0.3...|\n",
- "|Though not a complete waste of time, 'Eighteen' really wa...|[0.2099183, 0.5228606, 0.4451728, -0.031682458, -0.411756...|\n",
- "|This film did well at the box office, and the producers o...|[0.1809797, -0.036222238, -0.34149715, 0.06155738, -0.066...|\n",
- "|Peter Crawford discovers a comet on a collision course wi...|[-0.27548066, 0.196654, -0.24626443, -0.3938084, -0.55015...|\n",
- "|This tale of the upper-classes getting their come-uppance...|[0.24201535, 0.011018419, -0.080340445, 0.31388694, -0.28...|\n",
- "|Words almost fail me to describe how terrible this Irish ...|[0.05590127, -0.14539507, -0.14005487, -0.03891221, 0.444...|\n",
- "|This was the most uninteresting horror flick I have seen ...|[0.2715968, -0.012542339, -0.3189819, 0.05820581, 0.56001...|\n",
- "| Heart of Darkness was terrible|[0.15930629, 0.36501077, 0.10715161, 0.7634482, 0.2555183...|\n",
- "|I saw this movie when it was first released in Pittsburgh Pa|[-0.34647676, 0.11561544, -0.18874292, 0.36590466, -0.068...|\n",
- "|It was funny because the whole thing was so unrealistic, ...|[0.09473588, -0.4378593, 0.14436121, 0.0045354995, -0.085...|\n",
- "|Watching this movie, you just have to ask: What were they...|[0.43020678, -0.09714476, 0.13562134, 0.23126753, -0.0390...|\n",
- "| In a sense, this movie did not even compare to the novel|[0.28383228, -0.01896684, -0.37275153, 0.27034503, 0.2017...|\n",
- "|Poor Jane Austen ought to be glad she's not around to see...|[0.27462238, -0.32494652, 0.48243237, 0.07208576, 0.22470...|\n",
- "| I gave this movie a four-star rating for a few reasons|[0.311433, -0.09470633, -0.10863638, 0.07785072, -0.15611...|\n",
- "|It seems that Dee Snyder ran out of ideas halfway through...|[0.44354525, -0.08122053, -0.15206799, -0.29244322, 0.559...|\n",
- "|Now, let me see if I have this correct, a lunatic serial ...|[0.39831725, 0.15871589, -0.35366756, -0.11643555, -0.137...|\n",
- "|Tommy Lee Jones was the best Woodroe and no one can play ...|[-0.20960276, -0.157601, -0.30596414, -0.5181772, -0.0852...|\n",
- "|First of all, I would like to say that I am a fan of all ...|[0.25831848, -0.26871827, 0.026099432, -0.34598774, -0.18...|\n",
- "+------------------------------------------------------------+------------------------------------------------------------+\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "| input| embedding|\n",
+ "+--------------------------------------------------+--------------------------------------------------+\n",
+ "|Doesn't anyone bother to check where this kind ...|[0.118947476, -0.053823642, -0.29726124, 0.0720...|\n",
+ "|There were two things I hated about WASTED : Th...|[0.18953452, 0.11079162, 0.07503566, 0.01050696...|\n",
+ "|I'm rather surprised that anybody found this fi...|[-0.0010759671, -0.14203517, -0.06649738, 0.129...|\n",
+ "|Cultural Vandalism Is the new Hallmark producti...|[0.34815887, -0.2966917, -0.10905265, 0.1051652...|\n",
+ "|I was at Wrestlemania VI in Toronto as a 10 yea...|[0.45902696, 0.019472413, 0.28720972, -0.070724...|\n",
+ "| This movie has been done before|[-0.062292397, -0.025909504, -0.031942524, 0.01...|\n",
+ "|[ as a new resolution for this year 2005, i dec...|[0.3469342, -0.14378615, 0.30223376, -0.1102267...|\n",
+ "|This movie is over hyped!! I am sad to say that...|[0.13230576, -0.06588756, 0.0472389, 0.08353163...|\n",
+ "|This show had a promising start as sort of the ...|[-0.19361982, -0.14412567, 0.15149693, -0.17715...|\n",
+ "|MINOR PLOT SPOILERS AHEAD!!!
\n",
" \n",
"\n",
@@ -2330,19 +2505,19 @@
"9 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n",
"\n",
" preds \n",
- "0 [-5.7614846, -3.52228, -1.1202906, 13.053683, ... \n",
- "1 [-3.1390061, -8.71185, 0.82955813, -4.034869, ... \n",
- "2 [-3.046528, 0.3521706, 0.6788677, 0.72303534, ... \n",
- "3 [-2.401024, -7.6780066, 11.145876, 1.2857256, ... \n",
- "4 [-5.0012593, 3.806796, -0.8154834, -0.9550028,... \n",
- "5 [-5.0425925, -3.4815094, 1.641246, 3.608149, -... \n",
- "6 [-4.288771, 5.0072904, 0.27649477, -0.797148, ... \n",
- "7 [-2.2032878, -1.6879876, -5.874276, -0.5945335... \n",
- "8 [1.1337761, -3.1751056, -2.5246286, -5.028277,... \n",
- "9 [-0.92484117, -2.4703276, -5.023897, 1.46669, ... "
+ "0 [-4.826194, -1.8435744, 0.627148, 10.832781, -... \n",
+ "1 [-2.4269826, -5.924154, 1.9833497, -1.7571343,... \n",
+ "2 [-2.0138872, 0.8769828, 1.7760125, 1.6830662, ... \n",
+ "3 [-0.9644477, -4.9092045, 11.911339, 2.1581159,... \n",
+ "4 [-3.796785, 3.3146381, -0.6682853, -0.72531927... \n",
+ "5 [-3.6937218, -2.0567024, 2.735378, 4.155004, -... \n",
+ "6 [-2.782737, 4.3289156, 0.7818791, -0.45041704,... \n",
+ "7 [-1.6946642, -1.3310279, -3.7396994, 0.6101702... \n",
+ "8 [0.605576, -2.0545676, -0.44554028, -3.4689393... \n",
+ "9 [-1.876519, -1.993192, -3.1747675, 1.9569131, ... "
]
},
- "execution_count": 60,
+ "execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
@@ -2355,13 +2530,9 @@
},
{
"cell_type": "code",
- "execution_count": 61,
+ "execution_count": 65,
"id": "79d90a26",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
@@ -2370,13 +2541,9 @@
},
{
"cell_type": "code",
- "execution_count": 62,
+ "execution_count": 66,
"id": "4ca495f5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
"sample = preds.iloc[0]\n",
@@ -2388,13 +2555,9 @@
},
{
"cell_type": "code",
- "execution_count": 63,
+ "execution_count": 67,
"id": "a5d10903",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"data": {
@@ -2426,14 +2589,17 @@
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 68,
"id": "9c9fd967-5cd9-4265-add9-db5c1ccf9893",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -2447,31 +2613,39 @@
"[True]"
]
},
- "execution_count": 64,
+ "execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ "    import time\n",
+ " \n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
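`_use_stage_level_scheduling` is defined earlier in this notebook; its role is to pin one task (and the full GPU) per node for the shutdown stage, matching the one-server-per-node startup. A minimal sketch of such a helper, assuming Spark 3.4+ and default resource names; the notebook's actual version may size cores differently:

```python
# Hypothetical sketch of a stage-level scheduling helper (not the notebook's exact code).
from pyspark.resource import ResourceProfileBuilder, TaskResourceRequests

def _use_stage_level_scheduling(spark, rdd):
    # Ask for all of an executor's cores plus a full GPU per task, so exactly
    # one task runs per node and can reach that node's Triton server.
    executor_cores = int(spark.conf.get("spark.executor.cores", "5"))  # assumed default
    treqs = TaskResourceRequests().cpus(executor_cores).resource("gpu", 1.0)
    rp = ResourceProfileBuilder().require(treqs).build
    print(f"Requesting stage-level resources: (cores={executor_cores}, gpu=1.0)")
    return rdd.withResources(rp)
```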
{
"cell_type": "code",
- "execution_count": 65,
+ "execution_count": null,
"id": "f612dc0b-538f-4ecf-81f7-ef6b58c493ab",
"metadata": {},
"outputs": [],
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras-metadata_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras-metadata_tf.ipynb
deleted file mode 100644
index 007f6d8ae..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras-metadata_tf.ipynb
+++ /dev/null
@@ -1,1259 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "id": "8e6810cc-5982-4293-bfbd-c91ef0aca204",
- "metadata": {},
- "source": [
- "# Distributed model inference using TensorFlow Keras\n",
- "From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "858e3a8d",
- "metadata": {},
- "source": [
- "### Using TensorFlow\n",
- "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n",
- "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "cf329ac8-0763-44bc-b0f6-b634b7dc480e",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-10-03 17:41:30.112764: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2024-10-03 17:41:30.119504: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-10-03 17:41:30.126948: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-10-03 17:41:30.129111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "2024-10-03 17:41:30.134946: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
- "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2024-10-03 17:41:30.497048: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
- ]
- }
- ],
- "source": [
- "import os\n",
- "import shutil\n",
- "import subprocess\n",
- "import time\n",
- "import pandas as pd\n",
- "from PIL import Image\n",
- "import numpy as np\n",
- "import uuid\n",
- " \n",
- "import tensorflow as tf\n",
- "from tensorflow.keras.applications.resnet50 import ResNet50\n",
- " \n",
- "from pyspark.sql.functions import col, pandas_udf, PandasUDFType\n",
- "from pyspark.sql import SparkSession\n",
- "from pyspark import SparkConf"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "44d72768",
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
- "\n",
- "conf = SparkConf()\n",
- "if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
- "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
- "sc = spark.sparkContext"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "833e36bc",
- "metadata": {},
- "outputs": [],
- "source": [
- "# Enable GPU memory growth\n",
- "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
- "if gpus:\n",
- " try:\n",
- " for gpu in gpus:\n",
- " tf.config.experimental.set_memory_growth(gpu, True)\n",
- " except RuntimeError as e:\n",
- " print(e)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "950b0470-a21e-4778-a80e-b8f6ef792dff",
- "metadata": {},
- "outputs": [],
- "source": [
- "file_name = \"image_data.parquet\"\n",
- "output_file_path = \"predictions\""
- ]
- },
- {
- "cell_type": "markdown",
- "id": "968d08a7-66b9-444f-b362-d8df692aef1c",
- "metadata": {},
- "source": [
- "### Prepare trained model and data for inference"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "da083168-137f-492c-8769-d8f1e2111756",
- "metadata": {},
- "source": [
- "Load the ResNet-50 Model and broadcast the weights."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "2ddc715a-cdbc-4c49-93e9-58c9d88511da",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2024-10-03 17:41:32.482802: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 45311 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels.h5\n",
- "\u001b[1m102967424/102967424\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 0us/step\n"
- ]
- }
- ],
- "source": [
- "model = ResNet50()\n",
- "bc_model_weights = sc.broadcast(model.get_weights())"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "77dddfa3-e8df-4e8e-8251-64457f1ebf80",
- "metadata": {},
- "source": [
- "Load the data and save the datasets to one Parquet file."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "c0738bec-97d4-4946-8c49-5e6d07ff1afc",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\n",
- "\u001b[1m228813984/228813984\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 0us/step\n"
- ]
- }
- ],
- "source": [
- "import pathlib\n",
- "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n",
- "data_dir = tf.keras.utils.get_file(origin=dataset_url,\n",
- " fname='flower_photos',\n",
- " untar=True)\n",
- "data_dir = pathlib.Path(data_dir)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "014644f4-2a45-4474-8afb-0daf90043253",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "3670\n"
- ]
- }
- ],
- "source": [
- "image_count = len(list(data_dir.glob('*/*.jpg')))\n",
- "print(image_count)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "d54f470a-d308-4426-8ed0-33f95155bb4f",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "2048"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import os\n",
- "files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(data_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']\n",
- "files = files[:2048]\n",
- "len(files)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "fd883dc0-4846-4411-a4d6-4f5f252ac707",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/rishic/.keras/datasets/flower_photos\n"
- ]
- }
- ],
- "source": [
- "print(data_dir)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "64f94ee0-f1ea-47f6-a77e-be8da5d1b87a",
- "metadata": {},
- "outputs": [],
- "source": [
- "image_data = []\n",
- "for file in files:\n",
- " img = Image.open(file)\n",
- " img = img.resize([224, 224])\n",
- " data = np.asarray(img, dtype=\"float32\").reshape([224*224*3])\n",
- "\n",
- " image_data.append({\"data\": data})\n",
- "\n",
- "pandas_df = pd.DataFrame(image_data, columns=['data'])\n",
- "pandas_df.to_parquet(file_name)\n",
- "# os.makedirs(dbfs_file_path)\n",
- "# shutil.copyfile(file_name, dbfs_file_path+file_name)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "f2414b0f-58f2-4e4a-9d09-8ea95b38d413",
- "metadata": {},
- "source": [
- "### Save Model\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "670328e3-7274-4d78-b315-487750166a3f",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "0"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "subprocess.call(\"rm -rf resnet50_model\".split())\n",
- "model.export(\"resnet50_model\", verbose=0)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "b827ad56-1af0-41b7-be68-94bd203a2a70",
- "metadata": {},
- "source": [
- "### Load the data into Spark DataFrames"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "8ddc22d0-b88a-4906-bd47-bf247e34feeb",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "2048\n"
- ]
- }
- ],
- "source": [
- "from pyspark.sql.types import *\n",
- "df = spark.read.parquet(file_name)\n",
- "print(df.count())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "id": "c7adf1d9-1fa7-4456-ae32-cf7d1d43bfd3",
- "metadata": {},
- "outputs": [],
- "source": [
- "# spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1024\")\n",
- "spark.conf.set(\"spark.sql.parquet.columnarReaderBatchSize\", \"1024\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "97173c07-a96e-4262-b60f-82865b997e99",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "assert len(df.head()) > 0, \"`df` should not be empty\" # This line will fail if the vectorized reader runs out of memory"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "865929b0-b016-4de4-996d-7f16176cf49c",
- "metadata": {
- "tags": []
- },
- "source": [
- "### Model inference via pandas UDF"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "a67b3128-13c1-44f1-a0c0-7cf7a836fee3",
- "metadata": {},
- "outputs": [],
- "source": [
- "def parse_image(image_data):\n",
- " image = tf.image.convert_image_dtype(\n",
- " image_data, dtype=tf.float32) * (2. / 255) - 1\n",
- " image = tf.reshape(image, [224, 224, 3])\n",
- " return image"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "id": "7b33185f-6d1e-4ca9-9757-fdc3d736496b",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/rishic/anaconda3/envs/spark-dl-tf/lib/python3.11/site-packages/pyspark/sql/pandas/functions.py:407: UserWarning: In Python 3.6+ and Spark 3.0+, it is preferred to specify type hints for pandas UDF instead of specifying pandas UDF type which will be deprecated in the future releases. See SPARK-28264 for more details.\n",
- " warnings.warn(\n"
- ]
- }
- ],
- "source": [
- "@pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR_ITER)\n",
- "def predict_batch_udf(image_batch_iter):\n",
- "\n",
- " # Enable GPU memory growth to avoid CUDA OOM\n",
- " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
- " if gpus:\n",
- " try:\n",
- " for gpu in gpus:\n",
- " tf.config.experimental.set_memory_growth(gpu, True)\n",
- " except RuntimeError as e:\n",
- " print(e)\n",
- "\n",
- " batch_size = 64\n",
- " model = ResNet50(weights=None)\n",
- " model.set_weights(bc_model_weights.value)\n",
- " for image_batch in image_batch_iter:\n",
- " images = np.vstack(image_batch)\n",
- " dataset = tf.data.Dataset.from_tensor_slices(images)\n",
- " dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(\n",
- " 5000).batch(batch_size)\n",
- " preds = model.predict(dataset)\n",
- " yield pd.Series(list(preds))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "id": "ad8c05da-db38-45ef-81d0-1f862f575ced",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| prediction|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "|[1.2938889E-4, 2.4666305E-4, 6.765791E-5, 1.2263245E-4, 5.7486624E-5, 3.9616702E-4, 7.0566134E-6, 4.1722178E-5, 1.225...|\n",
- "|[4.4501914E-5, 3.5403698E-4, 4.6702033E-5, 8.102543E-5, 3.1704556E-5, 1.9194305E-4, 7.905952E-6, 1.3744082E-4, 1.9563...|\n",
- "|[1.05672516E-4, 2.2686279E-4, 3.0055395E-5, 6.523785E-5, 2.352077E-5, 3.7122983E-4, 3.3315896E-6, 2.2584054E-5, 9.775...|\n",
- "|[2.0331638E-5, 2.2746396E-4, 7.828012E-5, 6.986782E-5, 4.705316E-5, 9.80732E-5, 5.561918E-6, 2.3519044E-4, 1.3803913E...|\n",
- "|[1.130241E-4, 2.3187004E-4, 5.296914E-5, 1.0871329E-4, 4.027478E-5, 3.7183522E-4, 5.5931855E-6, 3.4792112E-5, 1.14155...|\n",
- "|[9.094467E-5, 2.06384E-4, 4.514821E-5, 7.665891E-5, 3.2262324E-5, 3.3875552E-4, 3.831814E-6, 4.1848412E-5, 9.94389E-6...|\n",
- "|[1.07847634E-4, 3.7848807E-4, 7.660533E-5, 1.2446754E-4, 4.7595917E-5, 3.333814E-4, 1.0669675E-5, 9.133265E-5, 1.8015...|\n",
- "|[2.2261223E-5, 2.734666E-4, 3.8122747E-5, 6.2266954E-5, 1.7935155E-5, 1.7268128E-4, 6.034271E-6, 1.06450585E-4, 1.789...|\n",
- "|[1.1065645E-4, 2.900581E-4, 4.2585547E-5, 1.074203E-4, 3.052314E-5, 4.794604E-4, 6.4872897E-6, 3.646897E-5, 1.3717402...|\n",
- "|[9.673917E-5, 2.058331E-4, 7.4652424E-5, 1.1323769E-4, 4.6106186E-5, 2.8604185E-4, 5.62365E-6, 5.471466E-5, 9.664386E...|\n",
- "|[7.411196E-5, 3.291524E-4, 1.3454164E-4, 1.7738447E-4, 8.467504E-5, 2.2466244E-4, 1.3621126E-5, 1.1778668E-4, 1.83372...|\n",
- "|[8.721524E-5, 2.7338538E-4, 3.5964815E-5, 7.792533E-5, 2.3559302E-5, 3.6789547E-4, 3.5665628E-6, 3.648153E-5, 1.07589...|\n",
- "|[9.723709E-5, 2.7619812E-4, 5.7464153E-5, 1.10104906E-4, 3.8317143E-5, 3.490506E-4, 6.1553183E-6, 4.413095E-5, 1.1236...|\n",
- "|[6.940235E-5, 2.5377885E-4, 5.057188E-5, 1.1485363E-4, 3.0059196E-5, 2.7862669E-4, 5.024019E-6, 5.1511077E-5, 1.16149...|\n",
- "|[4.2095784E-5, 2.4891715E-4, 1.236292E-4, 1.4306813E-4, 7.3354306E-5, 1.6047148E-4, 7.958807E-6, 1.3556339E-4, 1.4698...|\n",
- "|[2.7327887E-5, 3.8553146E-4, 1.2939748E-4, 1.5762268E-4, 7.307493E-5, 8.5530424E-5, 1.2648808E-5, 1.9154618E-4, 2.307...|\n",
- "|[3.036101E-5, 3.5572305E-4, 1.600718E-4, 2.1437313E-4, 8.063033E-5, 1.02061334E-4, 1.3876456E-5, 1.561292E-4, 1.63637...|\n",
- "|[3.3109587E-5, 2.8182982E-4, 1.7998899E-4, 2.0246049E-4, 1.3720036E-4, 1.01000114E-4, 3.427488E-5, 3.887249E-4, 3.189...|\n",
- "|[4.549448E-5, 2.8782588E-4, 2.3703449E-4, 2.448979E-4, 1.20997625E-4, 1.3744453E-4, 1.62803E-5, 2.2094708E-4, 1.56962...|\n",
- "|[1.2242574E-4, 2.8095162E-4, 6.332559E-5, 1.0209269E-4, 4.335324E-5, 3.906304E-4, 8.205706E-6, 6.202823E-5, 1.5312888...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 10.8 ms, sys: 4.93 ms, total: 15.7 ms\n",
- "Wall time: 9.25 s\n"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions_df = df.select(predict_batch_udf(col(\"data\")).alias(\"prediction\"))\n",
- "predictions_df.show(truncate=120)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "id": "40799f8e-443e-40ca-919b-391f901cb3f4",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 8:===================================================> (7 + 1) / 8]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 9.96 ms, sys: 3.32 ms, total: 13.3 ms\n",
- "Wall time: 14 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions_df.write.mode(\"overwrite\").parquet(output_file_path)"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "16726357-65d8-4d3d-aea1-6800101741cc",
- "metadata": {
- "tags": []
- },
- "source": [
- "### Model inference using Spark DL API"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "id": "e6af27b2-ddc0-42ee-94cc-9ba5ffee6868",
- "metadata": {},
- "outputs": [],
- "source": [
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import struct, col\n",
- "from pyspark.sql.types import ArrayType, FloatType"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "id": "dda88b46-6300-4bf7-bc10-7403f4fbbf92",
- "metadata": {},
- "outputs": [],
- "source": [
- "def predict_batch_fn():\n",
- " import tensorflow as tf\n",
- " from tensorflow.keras.applications.resnet50 import ResNet50\n",
- "\n",
- " # Enable GPU memory growth\n",
- " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
- " if gpus:\n",
- " try:\n",
- " for gpu in gpus:\n",
- " tf.config.experimental.set_memory_growth(gpu, True)\n",
- " except RuntimeError as e:\n",
- " print(e)\n",
- "\n",
- " model = ResNet50()\n",
- " def predict(inputs):\n",
- " inputs = inputs * (2. / 255) - 1\n",
- " return model.predict(inputs)\n",
- " return predict"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "cff0e851-563d-40b6-9d05-509c22b3b7f9",
- "metadata": {},
- "outputs": [],
- "source": [
- "classify = predict_batch_udf(predict_batch_fn,\n",
- " input_tensor_shapes=[[224, 224, 3]],\n",
- " return_type=ArrayType(FloatType()),\n",
- " batch_size=50)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "id": "f733c38b-867d-48c1-b9a6-74a931561896",
- "metadata": {},
- "outputs": [],
- "source": [
- "# spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1024\")\n",
- "spark.conf.set(\"spark.sql.parquet.columnarReaderBatchSize\", \"1024\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "id": "aa7c156f-e2b3-4837-9427-ccf3a5720412",
- "metadata": {},
- "outputs": [],
- "source": [
- "df = spark.read.parquet(\"image_data.parquet\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "id": "80bc50ad-eaf5-4fce-a354-5e17d65e2da5",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 11:===========================================> (3 + 1) / 4]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| prediction|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "|[1.296447E-4, 2.465122E-4, 6.7463385E-5, 1.2231144E-4, 5.731739E-5, 3.9644213E-4, 7.0297688E-6, 4.1668914E-5, 1.22212...|\n",
- "|[4.4481887E-5, 3.526653E-4, 4.683818E-5, 8.1168495E-5, 3.178377E-5, 1.9188467E-4, 7.885617E-6, 1.3758946E-4, 1.956621...|\n",
- "|[1.05946536E-4, 2.2744355E-4, 3.0219735E-5, 6.548672E-5, 2.3649674E-5, 3.7177472E-4, 3.353236E-6, 2.271976E-5, 9.8115...|\n",
- "|[2.0392703E-5, 2.2817637E-4, 7.840744E-5, 6.9875685E-5, 4.702542E-5, 9.8244425E-5, 5.5829764E-6, 2.3530141E-4, 1.3836...|\n",
- "|[1.1312391E-4, 2.31244E-4, 5.279228E-5, 1.0859927E-4, 4.0202678E-5, 3.721753E-4, 5.563934E-6, 3.4674114E-5, 1.1389492...|\n",
- "|[9.126345E-5, 2.0679034E-4, 4.5165678E-5, 7.679106E-5, 3.234611E-5, 3.3994843E-4, 3.84E-6, 4.1930372E-5, 9.949454E-6,...|\n",
- "|[1.07930486E-4, 3.7741542E-4, 7.613175E-5, 1.2414041E-4, 4.7409427E-5, 3.332554E-4, 1.05853915E-5, 9.1026224E-5, 1.79...|\n",
- "|[2.2216762E-5, 2.7354853E-4, 3.8192928E-5, 6.2340725E-5, 1.7952003E-5, 1.7253387E-4, 6.020507E-6, 1.0669143E-4, 1.786...|\n",
- "|[1.10480236E-4, 2.89734E-4, 4.239379E-5, 1.0727814E-4, 3.047985E-5, 4.7992737E-4, 6.4530495E-6, 3.6428817E-5, 1.36967...|\n",
- "|[9.6864875E-5, 2.0573521E-4, 7.4498465E-5, 1.1323085E-4, 4.6088306E-5, 2.8680824E-4, 5.604823E-6, 5.461046E-5, 9.6629...|\n",
- "|[7.4198484E-5, 3.2886668E-4, 1.3441108E-4, 1.7755068E-4, 8.469927E-5, 2.2534095E-4, 1.3617541E-5, 1.1781904E-4, 1.833...|\n",
- "|[8.7561886E-5, 2.7312653E-4, 3.5959012E-5, 7.7946424E-5, 2.3565723E-5, 3.6881721E-4, 3.5630535E-6, 3.642736E-5, 1.074...|\n",
- "|[9.743975E-5, 2.7615853E-4, 5.74148E-5, 1.10329434E-4, 3.83045E-5, 3.500394E-4, 6.167429E-6, 4.4207005E-5, 1.1250093E...|\n",
- "|[6.9320704E-5, 2.53287E-4, 5.0612853E-5, 1.14936556E-4, 3.0210098E-5, 2.7870742E-4, 5.031114E-6, 5.169024E-5, 1.16021...|\n",
- "|[4.2203726E-5, 2.4911022E-4, 1.2378568E-4, 1.4274308E-4, 7.32259E-5, 1.6058519E-4, 7.9425035E-6, 1.3519496E-4, 1.4662...|\n",
- "|[2.7190901E-5, 3.8381666E-4, 1.2918573E-4, 1.570463E-4, 7.310112E-5, 8.554618E-5, 1.2614603E-5, 1.9213595E-4, 2.30354...|\n",
- "|[3.0573912E-5, 3.5561546E-4, 1.5945674E-4, 2.1361349E-4, 8.046549E-5, 1.0269262E-4, 1.3862439E-5, 1.5622783E-4, 1.638...|\n",
- "|[3.3117096E-5, 2.8073433E-4, 1.7961214E-4, 2.020287E-4, 1.3662946E-4, 1.0117796E-4, 3.4090703E-5, 3.8897162E-4, 3.181...|\n",
- "|[4.5728237E-5, 2.8880237E-4, 2.3783019E-4, 2.4589908E-4, 1.2160292E-4, 1.3812551E-4, 1.6343482E-5, 2.2073709E-4, 1.57...|\n",
- "|[1.2280059E-4, 2.806991E-4, 6.3642765E-5, 1.02471764E-4, 4.351664E-5, 3.9150563E-4, 8.235125E-6, 6.211928E-5, 1.53269...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 8.12 ms, sys: 3.38 ms, total: 11.5 ms\n",
- "Wall time: 5.59 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "# first pass caches model/fn\n",
- "predictions = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n",
- "predictions.show(truncate=120)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 27,
- "id": "41cace80-7a4b-4929-8e63-9c83f9745e02",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 13:===========================================> (3 + 1) / 4]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| prediction|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "|[1.293178E-4, 2.4644283E-4, 6.760039E-5, 1.2260793E-4, 5.7431564E-5, 3.9597694E-4, 7.0522524E-6, 4.1717416E-5, 1.2240...|\n",
- "|[4.4487308E-5, 3.5378174E-4, 4.6667028E-5, 8.102564E-5, 3.168566E-5, 1.9189132E-4, 7.903805E-6, 1.3741471E-4, 1.95482...|\n",
- "|[1.0566196E-4, 2.2684377E-4, 3.00564E-5, 6.5251304E-5, 2.3520754E-5, 3.7116173E-4, 3.331476E-6, 2.2584616E-5, 9.77515...|\n",
- "|[2.0337258E-5, 2.2749524E-4, 7.8351426E-5, 6.991163E-5, 4.7081656E-5, 9.8092445E-5, 5.564894E-6, 2.3517481E-4, 1.3805...|\n",
- "|[1.12979564E-4, 2.3172122E-4, 5.2946547E-5, 1.0876398E-4, 4.0259067E-5, 3.7143996E-4, 5.5940513E-6, 3.4814777E-5, 1.1...|\n",
- "|[9.093228E-5, 2.0639994E-4, 4.5151268E-5, 7.666316E-5, 3.2264295E-5, 3.387436E-4, 3.832487E-6, 4.185193E-5, 9.944773E...|\n",
- "|[1.0783461E-4, 3.7850672E-4, 7.660902E-5, 1.2446321E-4, 4.7591406E-5, 3.3328883E-4, 1.067249E-5, 9.131178E-5, 1.80121...|\n",
- "|[2.2258617E-5, 2.7345872E-4, 3.814439E-5, 6.229726E-5, 1.79387E-5, 1.7259057E-4, 6.0371217E-6, 1.0649798E-4, 1.789726...|\n",
- "|[1.1067773E-4, 2.8997674E-4, 4.2570035E-5, 1.0747747E-4, 3.0524247E-5, 4.7921995E-4, 6.489833E-6, 3.6502548E-5, 1.371...|\n",
- "|[9.676251E-5, 2.0588847E-4, 7.467098E-5, 1.1326933E-4, 4.6123736E-5, 2.8609246E-4, 5.627118E-6, 5.4726373E-5, 9.66839...|\n",
- "|[7.4104944E-5, 3.290917E-4, 1.3448784E-4, 1.7742367E-4, 8.463227E-5, 2.2462371E-4, 1.3614881E-5, 1.17794625E-4, 1.833...|\n",
- "|[8.7211796E-5, 2.7337394E-4, 3.5953894E-5, 7.7924225E-5, 2.3554327E-5, 3.67775E-4, 3.5652213E-6, 3.647724E-5, 1.07577...|\n",
- "|[9.7237185E-5, 2.762026E-4, 5.7450008E-5, 1.1019135E-4, 3.831896E-5, 3.4878452E-4, 6.1574788E-6, 4.415526E-5, 1.12374...|\n",
- "|[6.938849E-5, 2.5376282E-4, 5.0565883E-5, 1.14880335E-4, 3.0061366E-5, 2.7866007E-4, 5.024482E-6, 5.152425E-5, 1.1617...|\n",
- "|[4.2096388E-5, 2.4889092E-4, 1.2363133E-4, 1.4304162E-4, 7.337785E-5, 1.6042824E-4, 7.959722E-6, 1.3552785E-4, 1.4693...|\n",
- "|[2.730248E-5, 3.851789E-4, 1.293143E-4, 1.5753493E-4, 7.302161E-5, 8.547956E-5, 1.26348905E-5, 1.9148648E-4, 2.304900...|\n",
- "|[3.0354899E-5, 3.5562844E-4, 1.6008675E-4, 2.1440513E-4, 8.062159E-5, 1.02023136E-4, 1.3876455E-5, 1.5611007E-4, 1.63...|\n",
- "|[3.3083066E-5, 2.8158593E-4, 1.7979987E-4, 2.0232225E-4, 1.3704685E-4, 1.0091762E-4, 3.4243407E-5, 3.8870922E-4, 3.18...|\n",
- "|[4.5485373E-5, 2.878148E-4, 2.3707838E-4, 2.4493985E-4, 1.21028905E-4, 1.3738636E-4, 1.6280053E-5, 2.2104722E-4, 1.56...|\n",
- "|[1.22468E-4, 2.809503E-4, 6.3342835E-5, 1.021957E-4, 4.3373006E-5, 3.905496E-4, 8.212427E-6, 6.2081075E-5, 1.5323925E...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 2.75 ms, sys: 3.03 ms, total: 5.78 ms\n",
- "Wall time: 4.79 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions = df.select(classify(\"data\").alias(\"prediction\"))\n",
- "predictions.show(truncate=120)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "id": "56a2ec8a-de09-4d7c-9666-1b3c76f10657",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 14:==================================================> (7 + 1) / 8]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 10.7 ms, sys: 4.25 ms, total: 14.9 ms\n",
- "Wall time: 16.9 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions = df.select(classify(col(\"data\")).alias(\"prediction\"))\n",
- "predictions.write.mode(\"overwrite\").parquet(output_file_path + \"_1\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "0efa48e9-2eda-4d57-8174-8850e5bca4af",
- "metadata": {},
- "source": [
- "### Using Triton Inference Server\n",
- "\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "id": "4666e618-8038-4dc5-9be7-793aedbf4500",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "%%bash\n",
- "# copy model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models/resnet50/1\n",
- "cp -r resnet50_model models/resnet50/1/model.savedmodel\n",
- "\n",
- "# add config.pbtxt\n",
- "cp models_config/resnet50/config.pbtxt models/resnet50/config.pbtxt"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "e07f1d6d-334e-4f85-9472-171dda09bae4",
- "metadata": {},
- "source": [
- "#### Start Triton Server on each executor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 31,
- "id": "8c8c0744-0558-4dac-bbfe-8bdde4b2af2d",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 31,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"512M\",\n",
- " volumes={triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"}}\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- "\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
- "\n",
- " return [True]\n",
- "\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "8c07365c-0a14-49b3-9bd8-cfb35f48b089",
- "metadata": {},
- "source": [
- "#### Run inference"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "id": "bcd46360-6851-4a9d-8590-c086e001242a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "def triton_fn(triton_uri, model_name):\n",
- " import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " inputs = inputs * (2. / 255) - 1 # add normalization\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "id": "9fabcaeb-5a44-42bb-8097-5dbc2d0cee3e",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "\n",
- "classify = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"resnet50\"),\n",
- " input_tensor_shapes=[[224, 224, 3]],\n",
- " return_type=ArrayType(FloatType()),\n",
- " batch_size=50)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "id": "b17f33c8-a0f0-4bce-91f8-5838ba9b12a7",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "# spark.conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1024\")\n",
- "spark.conf.set(\"spark.sql.parquet.columnarReaderBatchSize\", \"1024\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "id": "8e5b9e99-a1cf-43d3-a795-c7271a917057",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "df = spark.read.parquet(\"image_data.parquet\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 36,
- "id": "e595473d-1a5d-46a6-a6ba-89d2ea903de9",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 18:===========================================> (3 + 1) / 4]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| prediction|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "|[1.2838157E-4, 2.442499E-4, 6.756602E-5, 1.223822E-4, 5.718728E-5, 3.9370774E-4, 6.9826538E-6, 4.180329E-5, 1.21474E-...|\n",
- "|[4.3975022E-5, 3.5182733E-4, 4.6756446E-5, 8.051952E-5, 3.157192E-5, 1.8915786E-4, 7.8848925E-6, 1.3820908E-4, 1.9617...|\n",
- "|[1.0483801E-4, 2.2482511E-4, 2.9800098E-5, 6.471683E-5, 2.3306355E-5, 3.6853546E-4, 3.2802545E-6, 2.2436941E-5, 9.655...|\n",
- "|[2.0184121E-5, 2.2646098E-4, 7.754879E-5, 6.9126E-5, 4.6796213E-5, 9.757494E-5, 5.5280707E-6, 2.3486002E-4, 1.3758638...|\n",
- "|[1.1207414E-4, 2.3036542E-4, 5.2748997E-5, 1.0843094E-4, 3.9970357E-5, 3.692824E-4, 5.5317682E-6, 3.467135E-5, 1.1321...|\n",
- "|[9.028466E-5, 2.0533502E-4, 4.5085282E-5, 7.65107E-5, 3.217092E-5, 3.3741904E-4, 3.8024857E-6, 4.1927728E-5, 9.920564...|\n",
- "|[1.0625615E-4, 3.759827E-4, 7.6174496E-5, 1.2342798E-4, 4.7335903E-5, 3.3091815E-4, 1.0598523E-5, 9.161089E-5, 1.7926...|\n",
- "|[2.2157477E-5, 2.726377E-4, 3.831429E-5, 6.2276886E-5, 1.8050652E-5, 1.7177712E-4, 6.0331595E-6, 1.06755506E-4, 1.790...|\n",
- "|[1.0993216E-4, 2.8824335E-4, 4.2543048E-5, 1.06903855E-4, 3.039875E-5, 4.7743318E-4, 6.441006E-6, 3.6423717E-5, 1.361...|\n",
- "|[9.6276366E-5, 2.047977E-4, 7.4698546E-5, 1.128771E-4, 4.6044628E-5, 2.8445767E-4, 5.6014956E-6, 5.475251E-5, 9.63856...|\n",
- "|[7.3160336E-5, 3.2700456E-4, 1.3447899E-4, 1.7689951E-4, 8.4440886E-5, 2.2350134E-4, 1.3515168E-5, 1.1746432E-4, 1.81...|\n",
- "|[8.632592E-5, 2.7143923E-4, 3.583003E-5, 7.763873E-5, 2.3417528E-5, 3.6477615E-4, 3.527159E-6, 3.646688E-5, 1.0721673...|\n",
- "|[9.640316E-5, 2.7391897E-4, 5.7131063E-5, 1.09568326E-4, 3.8045353E-5, 3.472495E-4, 6.057242E-6, 4.3799748E-5, 1.1118...|\n",
- "|[6.912533E-5, 2.5222785E-4, 5.0288483E-5, 1.1415517E-4, 2.9881658E-5, 2.7816373E-4, 4.972507E-6, 5.121496E-5, 1.15293...|\n",
- "|[4.189945E-5, 2.4779947E-4, 1.2303083E-4, 1.4200866E-4, 7.2787174E-5, 1.600041E-4, 7.901948E-6, 1.3503798E-4, 1.46427...|\n",
- "|[2.7033573E-5, 3.8410365E-4, 1.2880778E-4, 1.5630701E-4, 7.2431474E-5, 8.455686E-5, 1.2551222E-5, 1.9146077E-4, 2.293...|\n",
- "|[2.9902518E-5, 3.521676E-4, 1.6034822E-4, 2.1348803E-4, 8.053424E-5, 1.00774814E-4, 1.3777179E-5, 1.5595586E-4, 1.615...|\n",
- "|[3.2834323E-5, 2.8044736E-4, 1.8003663E-4, 2.017913E-4, 1.3718085E-4, 1.0062256E-4, 3.4619785E-5, 3.8973117E-4, 3.187...|\n",
- "|[4.4552748E-5, 2.8623734E-4, 2.3419394E-4, 2.4108509E-4, 1.1926766E-4, 1.3529808E-4, 1.6018543E-5, 2.210266E-4, 1.558...|\n",
- "|[1.2160183E-4, 2.8021698E-4, 6.289166E-5, 1.0147789E-4, 4.3161614E-5, 3.8964444E-4, 8.174407E-6, 6.2043844E-5, 1.5228...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 4.79 ms, sys: 1.93 ms, total: 6.72 ms\n",
- "Wall time: 3.06 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "# first pass caches model/fn\n",
- "predictions = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n",
- "predictions.show(truncate=120)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "id": "5f66d468-e0b1-4589-8606-b3848063a823",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 20:===========================================> (3 + 1) / 4]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "| prediction|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "|[1.2838157E-4, 2.442499E-4, 6.756602E-5, 1.223822E-4, 5.718728E-5, 3.9370774E-4, 6.9826538E-6, 4.180329E-5, 1.21474E-...|\n",
- "|[4.3975022E-5, 3.5182733E-4, 4.6756446E-5, 8.051952E-5, 3.157192E-5, 1.8915786E-4, 7.8848925E-6, 1.3820908E-4, 1.9617...|\n",
- "|[1.0483801E-4, 2.2482511E-4, 2.9800098E-5, 6.471683E-5, 2.3306355E-5, 3.6853546E-4, 3.2802545E-6, 2.2436941E-5, 9.655...|\n",
- "|[2.0184121E-5, 2.2646098E-4, 7.754879E-5, 6.9126E-5, 4.6796213E-5, 9.757494E-5, 5.5280707E-6, 2.3486002E-4, 1.3758638...|\n",
- "|[1.1207414E-4, 2.3036542E-4, 5.2748997E-5, 1.0843094E-4, 3.9970357E-5, 3.692824E-4, 5.5317682E-6, 3.467135E-5, 1.1321...|\n",
- "|[9.028466E-5, 2.0533502E-4, 4.5085282E-5, 7.65107E-5, 3.217092E-5, 3.3741904E-4, 3.8024857E-6, 4.1927728E-5, 9.920564...|\n",
- "|[1.0625615E-4, 3.759827E-4, 7.6174496E-5, 1.2342798E-4, 4.7335903E-5, 3.3091815E-4, 1.0598523E-5, 9.161089E-5, 1.7926...|\n",
- "|[2.2157477E-5, 2.726377E-4, 3.831429E-5, 6.2276886E-5, 1.8050652E-5, 1.7177712E-4, 6.0331595E-6, 1.06755506E-4, 1.790...|\n",
- "|[1.0993216E-4, 2.8824335E-4, 4.2543048E-5, 1.06903855E-4, 3.039875E-5, 4.7743318E-4, 6.441006E-6, 3.6423717E-5, 1.361...|\n",
- "|[9.6276366E-5, 2.047977E-4, 7.4698546E-5, 1.128771E-4, 4.6044628E-5, 2.8445767E-4, 5.6014956E-6, 5.475251E-5, 9.63856...|\n",
- "|[7.3160336E-5, 3.2700456E-4, 1.3447899E-4, 1.7689951E-4, 8.4440886E-5, 2.2350134E-4, 1.3515168E-5, 1.1746432E-4, 1.81...|\n",
- "|[8.632592E-5, 2.7143923E-4, 3.583003E-5, 7.763873E-5, 2.3417528E-5, 3.6477615E-4, 3.527159E-6, 3.646688E-5, 1.0721673...|\n",
- "|[9.640316E-5, 2.7391897E-4, 5.7131063E-5, 1.09568326E-4, 3.8045353E-5, 3.472495E-4, 6.057242E-6, 4.3799748E-5, 1.1118...|\n",
- "|[6.912533E-5, 2.5222785E-4, 5.0288483E-5, 1.1415517E-4, 2.9881658E-5, 2.7816373E-4, 4.972507E-6, 5.121496E-5, 1.15293...|\n",
- "|[4.189945E-5, 2.4779947E-4, 1.2303083E-4, 1.4200866E-4, 7.2787174E-5, 1.600041E-4, 7.901948E-6, 1.3503798E-4, 1.46427...|\n",
- "|[2.7033573E-5, 3.8410365E-4, 1.2880778E-4, 1.5630701E-4, 7.2431474E-5, 8.455686E-5, 1.2551222E-5, 1.9146077E-4, 2.293...|\n",
- "|[2.9902518E-5, 3.521676E-4, 1.6034822E-4, 2.1348803E-4, 8.053424E-5, 1.00774814E-4, 1.3777179E-5, 1.5595586E-4, 1.615...|\n",
- "|[3.2834323E-5, 2.8044736E-4, 1.8003663E-4, 2.017913E-4, 1.3718085E-4, 1.0062256E-4, 3.4619785E-5, 3.8973117E-4, 3.187...|\n",
- "|[4.4552748E-5, 2.8623734E-4, 2.3419394E-4, 2.4108509E-4, 1.1926766E-4, 1.3529808E-4, 1.6018543E-5, 2.210266E-4, 1.558...|\n",
- "|[1.2160183E-4, 2.8021698E-4, 6.289166E-5, 1.0147789E-4, 4.3161614E-5, 3.8964444E-4, 8.174407E-6, 6.2043844E-5, 1.5228...|\n",
- "+------------------------------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 3.16 ms, sys: 3.36 ms, total: 6.52 ms\n",
- "Wall time: 2.24 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions = df.select(classify(\"data\").alias(\"prediction\"))\n",
- "predictions.show(truncate=120)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 38,
- "id": "632c4c3a-fa52-4c3d-b71e-7526286e353a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 21:==================================================> (7 + 1) / 8]\r"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "CPU times: user 7.57 ms, sys: 5.2 ms, total: 12.8 ms\n",
- "Wall time: 13.3 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- }
- ],
- "source": [
- "%%time\n",
- "predictions = df.select(classify(col(\"data\")).alias(\"prediction\"))\n",
- "predictions.write.mode(\"overwrite\").parquet(output_file_path + \"_2\")"
- ]
- },
- {
- "cell_type": "markdown",
- "id": "4dc06b7e-f750-40b5-9208-a035db11d937",
- "metadata": {
- "tags": []
- },
- "source": [
- "#### Stop Triton Server on each executor"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 39,
- "id": "bbfcaa51-3b9f-43ff-a4a8-4b46766115b8",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 39,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
- "\n",
- " return [True]\n",
- "\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 40,
- "id": "0d88639b-d934-4eb4-ae2f-cc13b9b10456",
- "metadata": {},
- "outputs": [],
- "source": [
- "spark.stop()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "df8cc28a-34d7-479c-be7e-9a380d39e25e",
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "spark-dl-tf",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/feature_columns_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_preprocessing_tf.ipynb
similarity index 57%
rename from examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/feature_columns_tf.ipynb
rename to examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_preprocessing_tf.ipynb
index 2ff37b6c6..f798b05d3 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/feature_columns_tf.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_preprocessing_tf.ipynb
@@ -5,9 +5,13 @@
"id": "7fcc021a",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# Pyspark TensorFlow Inference\n",
"\n",
- "## Feature Columns\n",
+ "### Classification using Keras Preprocessing Layers\n",
+ "\n",
+ "In this notebook, we demonstrate distributed inference using Keras preprocessing layers to classify structured data. \n",
"From: https://www.tensorflow.org/tutorials/structured_data/preprocessing_layers"
]
},
@@ -16,9 +20,7 @@
"id": "35203476",
"metadata": {},
"source": [
- "### Using TensorFlow\n",
- "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n",
- "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos."
+ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) "
]
},
{
@@ -31,17 +33,19 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-24 16:04:17.711230: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2024-10-24 16:04:17.719701: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-10-24 16:04:17.728758: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-10-24 16:04:17.731459: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "2024-10-24 16:04:17.738797: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "2025-01-06 21:14:44.313899: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-01-06 21:14:44.321948: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-01-06 21:14:44.330652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2025-01-06 21:14:44.333277: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2025-01-06 21:14:44.340133: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2024-10-24 16:04:18.115892: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ "2025-01-06 21:14:44.761403: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
+ "import os\n",
+ "import shutil\n",
"import numpy as np\n",
"import pandas as pd\n",
"import tensorflow as tf\n",
@@ -52,6 +56,16 @@
{
"cell_type": "code",
"execution_count": 2,
+ "id": "0d586fb8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdir('models') if not os.path.exists('models') else None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"id": "9fa3e1b7-58cd-45f9-9fee-85f25a31c3c6",
"metadata": {},
"outputs": [
@@ -76,12 +90,31 @@
" print(e)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "b2402b9a",
+ "metadata": {},
+ "source": [
+ "#### Download dataset\n",
+ "\n",
+ "Download the PetFinder dataset from Kaggle, which where each row describes a pet and the goal is to predict adoption speed."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"id": "9326b072-a53c-40c4-a6cb-bd4d3d644d03",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/petfinder-mini.zip\n",
+ "\u001b[1m1668792/1668792\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 0us/step\n"
+ ]
+ }
+ ],
"source": [
"import pathlib\n",
"import os\n",
@@ -100,7 +133,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 6,
"id": "e98480ef-d13d-44c0-a227-e9a22f9bf2b0",
"metadata": {},
"outputs": [
@@ -260,7 +293,7 @@
"4 This handsome yet cute boy is up for adoption.... 3 2 "
]
},
- "execution_count": 11,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -269,9 +302,17 @@
"dataframe.head()"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "27d844f1",
+ "metadata": {},
+ "source": [
+ "### Prepare dataset"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"id": "e8efce25-a835-4cbd-b8a2-1418ba2c1d31",
"metadata": {},
"outputs": [],
@@ -286,7 +327,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"id": "00d403cf-9ae7-4780-9fac-13d920d8b395",
"metadata": {},
"outputs": [
@@ -305,7 +346,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 9,
"id": "4206a56e-5403-42a9-805e-e037044e7995",
"metadata": {},
"outputs": [
@@ -325,9 +366,17 @@
"print(len(test), 'test examples')"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "a7fa64f8",
+ "metadata": {},
+ "source": [
+ "Create an input pipeline which converts each dataset into a tf.data.Dataset with shuffling and batching."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 10,
"id": "499ade5f-ac8a-47ca-a021-071239dfe97d",
"metadata": {},
"outputs": [],
@@ -344,9 +393,17 @@
" return ds"
]
},
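For context, the body of the `df_to_dataset` helper referenced above is collapsed in this hunk. A minimal sketch of the pipeline, following the TensorFlow structured-data tutorial linked at the top of the notebook (names taken from that tutorial, not from this diff):

```python
import pandas as pd
import tensorflow as tf

def df_to_dataset(dataframe: pd.DataFrame, shuffle: bool = True, batch_size: int = 32) -> tf.data.Dataset:
    df = dataframe.copy()
    labels = df.pop('target')
    # Wrap each column as a (rows, 1) array so the Keras preprocessing layers see 2D inputs.
    df = {key: value.to_numpy()[:, tf.newaxis] for key, value in df.items()}
    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(dataframe))
    return ds.batch(batch_size).prefetch(batch_size)
```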
+ {
+ "cell_type": "markdown",
+ "id": "96065bed",
+ "metadata": {},
+ "source": [
+ "Check the format of the data returned by the pipeline:"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 11,
"id": "b9ec57c9-080e-4626-9e03-acf309cf3736",
"metadata": {},
"outputs": [
@@ -354,7 +411,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:38:53.526119: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46022 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
+ "2025-01-06 21:14:45.776288: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 42175 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
]
}
],
@@ -363,9 +420,17 @@
"train_ds = df_to_dataset(train, batch_size=batch_size)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "bdc8571c",
+ "metadata": {},
+ "source": [
+ "(Note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963)."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 12,
"id": "dfcbf268-4508-4eb8-abe1-acf1dbb97bd5",
"metadata": {},
"outputs": [
@@ -375,19 +440,19 @@
"text": [
"Every feature: ['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n",
"A batch of ages: tf.Tensor(\n",
- "[[18]\n",
- " [ 5]\n",
+ "[[ 2]\n",
" [ 2]\n",
- " [ 5]\n",
- " [ 1]], shape=(5, 1), dtype=int64)\n",
- "A batch of targets: tf.Tensor([1 0 1 1 1], shape=(5,), dtype=int64)\n"
+ " [12]\n",
+ " [ 3]\n",
+ " [ 2]], shape=(5, 1), dtype=int64)\n",
+ "A batch of targets: tf.Tensor([1 1 1 1 1], shape=(5,), dtype=int64)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:38:53.588272: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-06 21:14:45.853063: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
}
],
@@ -398,9 +463,19 @@
"print('A batch of targets:', label_batch )"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "d5a2d10c",
+ "metadata": {},
+ "source": [
+ "### Apply Keras preprocessing layers\n",
+ "\n",
+ "We'll define a normalization layer for numeric features, and a category encoding for categorical features."
+ ]
+ },
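The helper functions that build these layers are collapsed in this diff; a sketch, again following the linked tutorial (function names from the tutorial):

```python
import tensorflow as tf

def get_normalization_layer(name, dataset):
    # Normalize a numeric feature to zero mean / unit variance,
    # learning the statistics from the training dataset.
    normalizer = tf.keras.layers.Normalization(axis=None)
    feature_ds = dataset.map(lambda x, y: x[name])
    normalizer.adapt(feature_ds)
    return normalizer

def get_category_encoding_layer(name, dataset, dtype, max_tokens=None):
    # Map raw categorical values to integer indices, then one-hot encode.
    if dtype == 'string':
        index = tf.keras.layers.StringLookup(max_tokens=max_tokens)
    else:
        index = tf.keras.layers.IntegerLookup(max_tokens=max_tokens)
    feature_ds = dataset.map(lambda x, y: x[name])
    index.adapt(feature_ds)
    encoder = tf.keras.layers.CategoryEncoding(num_tokens=index.vocabulary_size())
    return lambda feature: encoder(index(feature))
```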
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 13,
"id": "6c09dc4b-3a2a-44f5-b41c-821ec30b87b1",
"metadata": {},
"outputs": [],
@@ -420,7 +495,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 14,
"id": "59bb91dc-360a-4a89-a9ea-bebc1ddbf1b7",
"metadata": {},
"outputs": [
@@ -428,21 +503,21 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:38:55.015073: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-06 21:14:47.344194: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
},
{
"data": {
"text/plain": [
""
+ "array([[ 0.43390083],\n",
+ " [ 0.43390083],\n",
+ " [ 2.004037 ],\n",
+ " [ 1.0619552 ],\n",
+ " [-0.822208 ]], dtype=float32)>"
]
},
- "execution_count": 13,
+ "execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -455,7 +530,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 15,
"id": "4623b612-e924-472b-9ef4-c7f14f9f53c5",
"metadata": {},
"outputs": [],
@@ -484,7 +559,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 16,
"id": "0a40e9ee-20a5-4a42-8543-c267f99af55e",
"metadata": {},
"outputs": [
@@ -492,14 +567,14 @@
"data": {
"text/plain": [
""
+ " [0., 0., 1.],\n",
+ " [0., 0., 1.]], dtype=float32)>"
]
},
- "execution_count": 15,
+ "execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@@ -514,7 +589,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 17,
"id": "ff63a5cc-71f4-428e-9299-a8018edc7648",
"metadata": {},
"outputs": [
@@ -522,21 +597,21 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:38:56.454126: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-06 21:14:48.950627: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
},
{
"data": {
"text/plain": [
""
+ " [0., 0., 1., 0., 0.],\n",
+ " [0., 1., 0., 0., 0.]], dtype=float32)>"
]
},
- "execution_count": 16,
+ "execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
@@ -550,9 +625,19 @@
"test_age_layer(test_age_col)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "afefbcf2",
+ "metadata": {},
+ "source": [
+ "### Preprocess selected features\n",
+ "\n",
+ "Apply the preprocessing utility functions defined earlier. Add all the feature inputs to a list.\n"
+ ]
+ },
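The preprocessing loop itself is collapsed in the hunks below; a sketch of what it does, following the linked tutorial:

```python
all_inputs = []
encoded_features = []

# Numerical features: normalize with the adapted Normalization layer.
for header in ['PhotoAmt', 'Fee']:
    numeric_col = tf.keras.Input(shape=(1,), name=header)
    normalization_layer = get_normalization_layer(header, train_ds)
    all_inputs.append(numeric_col)
    encoded_features.append(normalization_layer(numeric_col))

# 'Age' is an integer-valued categorical feature.
age_col = tf.keras.Input(shape=(1,), name='Age', dtype='int64')
encoding_layer = get_category_encoding_layer(name='Age', dataset=train_ds,
                                             dtype='int64', max_tokens=5)
all_inputs.append(age_col)
encoded_features.append(encoding_layer(age_col))

# String-valued categorical features: lookup + one-hot encode.
categorical_cols = ['Type', 'Color1', 'Color2', 'Gender', 'MaturitySize',
                    'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Breed1']
for header in categorical_cols:
    categorical_col = tf.keras.Input(shape=(1,), name=header, dtype='string')
    encoding_layer = get_category_encoding_layer(name=header, dataset=train_ds,
                                                 dtype='string', max_tokens=5)
    all_inputs.append(categorical_col)
    encoded_features.append(encoding_layer(categorical_col))
```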
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 18,
"id": "2b040b0e-d8ca-4cf0-917c-dd9a272e1f0a",
"metadata": {},
"outputs": [],
@@ -565,7 +650,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 19,
"id": "19df498e-4dd1-467a-8741-e1f5e15932a5",
"metadata": {},
"outputs": [],
@@ -584,7 +669,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 20,
"id": "1d12579f-34fb-40b0-a16a-3e13cfea8178",
"metadata": {},
"outputs": [],
@@ -602,7 +687,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 21,
"id": "bff286eb-7ad7-4d3a-8fa4-c729692d1425",
"metadata": {},
"outputs": [
@@ -610,8 +695,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:38:56.758056: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
- "2024-10-03 17:38:57.171981: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-06 21:14:49.223773: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n",
+ "2025-01-06 21:14:49.659659: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
}
],
@@ -630,9 +715,17 @@
" encoded_features.append(encoded_categorical_col)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "e0dfac0d",
+ "metadata": {},
+ "source": [
+ "### Create, compile, and train model"
+ ]
+ },
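The model-building cells are likewise collapsed; a sketch of the build/compile/fit step from the linked tutorial, consistent with the 10-epoch training log below:

```python
# Concatenate the encoded features and attach a small classification head.
all_features = tf.keras.layers.concatenate(encoded_features)
x = tf.keras.layers.Dense(32, activation='relu')(all_features)
x = tf.keras.layers.Dropout(0.5)(x)
output = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(all_inputs, output)
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_ds, epochs=10, validation_data=val_ds)
```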
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 22,
"id": "79247436-32d8-4738-a656-3f288c77001c",
"metadata": {},
"outputs": [],
@@ -647,7 +740,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 23,
"id": "dbc85d3e-6d1e-4167-9516-b1182e880542",
"metadata": {},
"outputs": [],
@@ -660,7 +753,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 24,
"id": "bc9836c8-3c1a-41ad-8833-a946bafcfb00",
"metadata": {},
"outputs": [
@@ -668,41 +761,35 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Epoch 1/10\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.4109 - loss: 0.7333 - val_accuracy: 0.6898 - val_loss: 0.5666\n",
+ "Epoch 1/10\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.4366 - loss: 0.6702 - val_accuracy: 0.7062 - val_loss: 0.5727\n",
"Epoch 2/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6423 - loss: 0.5994 - val_accuracy: 0.7210 - val_loss: 0.5484\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.6428 - loss: 0.5924 - val_accuracy: 0.7236 - val_loss: 0.5529\n",
"Epoch 3/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 17ms/step - accuracy: 0.6825 - loss: 0.5728 - val_accuracy: 0.7253 - val_loss: 0.5383\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6820 - loss: 0.5629 - val_accuracy: 0.7418 - val_loss: 0.5404\n",
"Epoch 4/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 17ms/step - accuracy: 0.6796 - loss: 0.5653 - val_accuracy: 0.7331 - val_loss: 0.5314\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6850 - loss: 0.5624 - val_accuracy: 0.7400 - val_loss: 0.5313\n",
"Epoch 5/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.6853 - loss: 0.5584 - val_accuracy: 0.7348 - val_loss: 0.5259\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6953 - loss: 0.5633 - val_accuracy: 0.7374 - val_loss: 0.5253\n",
"Epoch 6/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7120 - loss: 0.5447 - val_accuracy: 0.7418 - val_loss: 0.5218\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.6982 - loss: 0.5516 - val_accuracy: 0.7392 - val_loss: 0.5199\n",
"Epoch 7/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7068 - loss: 0.5422 - val_accuracy: 0.7435 - val_loss: 0.5189\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 16ms/step - accuracy: 0.7049 - loss: 0.5438 - val_accuracy: 0.7444 - val_loss: 0.5166\n",
"Epoch 8/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7043 - loss: 0.5397 - val_accuracy: 0.7435 - val_loss: 0.5162\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7109 - loss: 0.5362 - val_accuracy: 0.7435 - val_loss: 0.5141\n",
"Epoch 9/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7172 - loss: 0.5372 - val_accuracy: 0.7496 - val_loss: 0.5146\n",
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7142 - loss: 0.5267 - val_accuracy: 0.7357 - val_loss: 0.5125\n",
"Epoch 10/10\n",
- "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 14ms/step - accuracy: 0.7337 - loss: 0.5232 - val_accuracy: 0.7409 - val_loss: 0.5131\n"
+ "\u001b[1m37/37\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m1s\u001b[0m 15ms/step - accuracy: 0.7102 - loss: 0.5249 - val_accuracy: 0.7383 - val_loss: 0.5094\n"
]
},
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 23,
+ "execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@@ -713,7 +800,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 25,
"id": "fbccebaa-fc24-4a58-a032-222cef8fdf08",
"metadata": {},
"outputs": [
@@ -721,8 +808,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.7480 - loss: 0.5028 \n",
- "Accuracy 0.753032922744751\n"
+ "\u001b[1m1/5\u001b[0m \u001b[32m━━━━\u001b[0m\u001b[37m━━━━━━━━━━━━━━━━\u001b[0m \u001b[1m0s\u001b[0m 11ms/step - accuracy: 0.7812 - loss: 0.4536"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\u001b[1m5/5\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 7ms/step - accuracy: 0.7601 - loss: 0.4909 \n",
+ "Accuracy 0.7478336095809937\n"
]
}
],
@@ -736,17 +830,9 @@
"id": "7534616c-8561-4869-b6e9-7254ebdb2c3f",
"metadata": {},
"source": [
- "## Save and Reload Model"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "id": "52425a31-7f21-415e-b166-7682c7eb282c",
- "metadata": {},
- "outputs": [],
- "source": [
- "import tensorflow as tf"
+ "### Save and reload model\n",
+ "\n",
+ "Demonstrate saving the trained model and reloading it for inference."
]
},
{
@@ -756,7 +842,7 @@
"metadata": {},
"outputs": [],
"source": [
- "model.save('my_pet_classifier.keras')"
+ "model.save('models/my_pet_classifier.keras')"
]
},
{
@@ -766,7 +852,7 @@
"metadata": {},
"outputs": [],
"source": [
- "reloaded_model = tf.keras.models.load_model('my_pet_classifier.keras')"
+ "reloaded_model = tf.keras.models.load_model('models/my_pet_classifier.keras')"
]
},
{
@@ -779,8 +865,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 26ms/step\n",
- "This particular pet had a 81.1 percent probability of getting adopted.\n"
+ "\u001b[1m1/1\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 28ms/step\n",
+ "This particular pet had a 78.4 percent probability of getting adopted.\n"
]
}
],
@@ -826,202 +912,241 @@
"metadata": {},
"outputs": [],
"source": [
+ "from pyspark.sql.functions import col, struct, pandas_udf\n",
+ "from pyspark.ml.functions import predict_batch_udf\n",
+ "from pyspark.sql.types import *\n",
+ "from pyspark.sql import SparkSession\n",
"from pyspark import SparkConf\n",
- "from pyspark.sql import SparkSession"
+ "import json\n",
+ "import pandas as pd"
]
},
{
- "cell_type": "code",
- "execution_count": null,
- "id": "60dff1da",
+ "cell_type": "markdown",
+ "id": "bb5aa875",
"metadata": {},
- "outputs": [],
"source": [
- "import os\n",
- "conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
- "\n",
- "conf = SparkConf()\n",
- "if 'spark' not in globals():\n",
- " # If Spark is not already started with Jupyter, attach to Spark Standalone\n",
- " import socket\n",
- " hostname = socket.gethostname()\n",
- " conf.setMaster(f\"spark://{hostname}:7077\") # assuming Master is on default port 7077\n",
- "conf.set(\"spark.task.maxFailures\", \"1\")\n",
- "conf.set(\"spark.driver.memory\", \"8g\")\n",
- "conf.set(\"spark.executor.memory\", \"8g\")\n",
- "conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
- "conf.set(\"spark.sql.execution.pyspark.udf.simplifiedTraceback.enabled\", \"false\")\n",
- "conf.set(\"spark.sql.pyspark.jvmStacktrace.enabled\", \"true\")\n",
- "conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
- "conf.set(\"spark.python.worker.reuse\", \"true\")\n",
- "# Create Spark Session\n",
- "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
- "sc = spark.sparkContext"
+ "Check the cluster environment to handle any platform-specific Spark configurations."
]
},
{
"cell_type": "code",
- "execution_count": 31,
- "id": "3c64fd7b-3d1e-40f8-ab64-b5c13f8bbe77",
+ "execution_count": 30,
+ "id": "7701420e",
"metadata": {},
"outputs": [],
"source": [
- "df = spark.createDataFrame(dataframe).repartition(8)"
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5e231dbd",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
]
},
{
"cell_type": "code",
- "execution_count": 32,
- "id": "1be8215b-5068-41b4-849c-1c3ea7bb108a",
+ "execution_count": 31,
+ "id": "60dff1da",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "25/01/06 21:14:57 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/06 21:14:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+ "25/01/06 21:14:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
]
}
],
"source": [
- "df.write.mode(\"overwrite\").parquet(\"datasets/petfinder-mini\")"
+ "conf = SparkConf()\n",
+ "\n",
+ "if 'spark' not in globals():\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " \n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " source = \"/usr/lib/x86_64-linux-gnu/libstdc++.so.6\"\n",
+ " target = f\"{conda_env}/lib/libstdc++.so.6\"\n",
+ " try:\n",
+ " if os.path.islink(target) or os.path.exists(target):\n",
+ " os.remove(target)\n",
+ " os.symlink(source, target)\n",
+ " except OSError as e:\n",
+ " print(f\"Error creating symlink: {e}\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\")\n",
+ " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"1000\")\n",
+ "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
+ "sc = spark.sparkContext"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fa2333d1",
+ "metadata": {},
+ "source": [
+ "#### Create PySpark DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "3c64fd7b-3d1e-40f8-ab64-b5c13f8bbe77",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.createDataFrame(dataframe).repartition(8)"
]
},
{
"cell_type": "code",
"execution_count": 33,
- "id": "d4dbde99-cf65-4c15-a163-754a0201a48d",
+ "id": "1be8215b-5068-41b4-849c-1c3ea7bb108a",
"metadata": {},
"outputs": [
{
- "name": "stdout",
+ "name": "stderr",
"output_type": "stream",
"text": [
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "| Cat| 3| Tabby| Male| Black| White| Small| Short| No| No|Healthy|100| 1| 1|\n",
- "| Cat| 1|Domestic Medium Hair| Male| Black| Brown| Medium| Medium| Not Sure| Not Sure|Healthy| 0| 2| 1|\n",
- "| Dog| 1| Mixed Breed| Male| Brown| White| Medium| Medium| Yes| No|Healthy| 0| 7| 1|\n",
- "| Dog| 4| Mixed Breed|Female| Black| Brown| Medium| Short| Yes| No|Healthy|150| 8| 1|\n",
- "| Dog| 1| Mixed Breed| Male| Black|No Color| Medium| Short| No| No|Healthy| 0| 3| 1|\n",
- "| Cat| 3| Domestic Short Hair|Female| Cream| Gray| Medium| Short| No| No|Healthy| 0| 2| 1|\n",
- "| Cat| 12| Domestic Long Hair| Male| Black|No Color| Medium| Long| No| Not Sure|Healthy|300| 3| 1|\n",
- "| Cat| 2|Domestic Medium Hair|Female| Gray|No Color| Medium| Medium| No| No|Healthy| 0| 6| 1|\n",
- "| Cat| 12|Domestic Medium Hair|Female| Black| White| Medium| Medium| Not Sure| Not Sure|Healthy| 0| 2| 0|\n",
- "| Dog| 2| Mixed Breed| Male| Black| Brown| Medium| Short| No| No|Healthy| 0| 7| 1|\n",
- "| Cat| 3| Domestic Long Hair|Female| Black| Brown| Large| Long| Yes| No|Healthy| 50| 2| 1|\n",
- "| Dog| 2| Mixed Breed| Male| Brown| Cream| Medium| Long| Yes| No|Healthy| 0| 1| 1|\n",
- "| Dog| 3| Mixed Breed|Female| Brown| Cream| Medium| Medium| Not Sure| Not Sure|Healthy| 0| 2| 1|\n",
- "| Dog| 78| Terrier| Male| Black| White| Medium| Medium| Not Sure| Not Sure|Healthy| 0| 2| 0|\n",
- "| Cat| 6| Domestic Short Hair|Female| Brown|No Color| Small| Short| Yes| Yes|Healthy| 0| 1| 1|\n",
- "| Dog| 8| Mixed Breed|Female| Brown|No Color| Medium| Short| No| Yes|Healthy| 10| 2| 0|\n",
- "| Dog| 2| Mixed Breed|Female| Black|No Color| Medium| Short| No| No|Healthy| 0| 8| 1|\n",
- "| Dog| 12| Mixed Breed|Female| Brown| White| Medium| Medium| No| Yes|Healthy| 0| 7| 1|\n",
- "| Dog| 10| Mixed Breed|Female| Black| Brown| Medium| Medium| Yes| Yes|Healthy| 0| 0| 0|\n",
- "| Cat| 3| Domestic Short Hair| Male| Brown| White| Small| Short| No| No|Healthy| 0| 19| 1|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "only showing top 20 rows\n",
- "\n"
+ " \r"
]
}
],
"source": [
- "df.show()"
+ "data_path = \"spark-dl-datasets/petfinder-mini\"\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
]
},
{
"cell_type": "markdown",
- "id": "efa3e424-2920-44eb-afa0-885e40b620ed",
+ "id": "7cec4e0e",
"metadata": {},
"source": [
- "## Inference using Spark DL API"
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 34,
- "id": "4c21296c-20ed-43f8-921a-c85a820d1819",
+ "id": "0892f845",
"metadata": {},
"outputs": [],
"source": [
- "import numpy as np\n",
- "import os\n",
- "\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import struct, col\n",
- "from pyspark.sql.types import ArrayType, FloatType"
+ "df = spark.read.parquet(data_path).cache()"
]
},
{
"cell_type": "code",
"execution_count": 35,
- "id": "04b38f3a-70ea-4746-9f52-c50087401508",
+ "id": "952645dd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "| Cat| 1|Domestic Medium Hair|Female| White|No Color| Small| Medium| No| No|Healthy| 0| 2| 1|\n",
- "| Dog| 2| Mixed Breed|Female| Black| Brown| Medium| Medium| No| No|Healthy| 0| 3| 1|\n",
- "| Dog| 18| Dalmatian|Female| Black| White| Medium| Medium| Yes| No|Healthy|350| 5| 1|\n",
- "| Dog| 3| Mixed Breed|Female| Black|No Color| Medium| Short| No| No|Healthy| 0| 1| 0|\n",
- "| Dog| 2| Mixed Breed| Male| Black| Brown| Medium| Short| No| No|Healthy| 0| 1| 1|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "only showing top 5 rows\n",
- "\n"
+ "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n"
]
}
],
"source": [
- "df = spark.read.parquet(\"datasets/petfinder-mini\").cache()\n",
- "df.show(5)"
+ "columns = df.columns\n",
+ "print(columns)"
]
},
{
"cell_type": "code",
"execution_count": 36,
- "id": "29c27243-7c74-4045-aaf1-f75a322c0530",
+ "id": "b9c24c0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n"
+ "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']\n"
]
}
],
"source": [
- "columns = df.columns\n",
+ "# remove label column\n",
+ "columns.remove(\"target\")\n",
"print(columns)"
]
},
{
"cell_type": "code",
"execution_count": 37,
- "id": "47508b14-97fa-42ee-a7d0-6175e6408283",
- "metadata": {
- "tags": []
- },
+ "id": "d4dbde99-cf65-4c15-a163-754a0201a48d",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']\n"
+ "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
+ "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\n",
+ "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
+ "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure|Healthy| 0| 2| 0|\n",
+ "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No|Healthy| 0| 4| 1|\n",
+ "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure|Healthy| 0| 4| 1|\n",
+ "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No|Healthy| 0| 3| 1|\n",
+ "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| No|Healthy| 0| 4| 1|\n",
+ "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
+ "only showing top 5 rows\n",
+ "\n"
]
}
],
"source": [
- "# remove label column\n",
- "columns.remove(\"target\")\n",
- "print(columns)"
+ "df.show(5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "824d7f97",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
]
},
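A minimal sketch of how these two pieces fit together; the names `make_predict_fn` and `classify` are illustrative (the notebook's actual `predict_batch_fn` appears in the hunks below), and `columns`/`model_path` refer to the variables defined in the surrounding cells:

```python
import numpy as np
from pyspark.ml.functions import predict_batch_udf
from pyspark.sql.types import FloatType

def make_predict_fn():
    # Runs once per Python worker: load the model on the executor,
    # not on the driver, so it isn't serialized into the UDF closure.
    import tensorflow as tf
    model = tf.keras.models.load_model(model_path)

    def predict(*cols):
        # predict_batch_udf hands in one numpy array per bound column;
        # reshape to (batch, 1) to match the named Keras inputs.
        inputs = {name: np.expand_dims(col, axis=1) for name, col in zip(columns, cols)}
        return model.predict(inputs)

    return predict

classify = predict_batch_udf(make_predict_fn,
                             return_type=FloatType(),
                             batch_size=100)
# Usage: df.withColumn("preds", classify(*columns))
```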
{
@@ -1032,7 +1157,21 @@
"outputs": [],
"source": [
"# get absolute path to model\n",
- "model_dir = \"{}/my_pet_classifier.keras\".format(os.getcwd())"
+ "model_path = \"{}/models/my_pet_classifier.keras\".format(os.getcwd())\n",
+ "\n",
+ "# For cloud environments, copy the model to the distributed file system.\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+ " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/my_pet_classifier.keras\"\n",
+ " shutil.copy(model_path, dbfs_model_path)\n",
+ " model_path = dbfs_model_path\n",
+ "elif on_dataproc:\n",
+ " # GCS is mounted at /mnt/gcs by the init script\n",
+ " models_dir = \"/mnt/gcs/spark-dl/models\"\n",
+ " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+ " gcs_model_path = models_dir + \"/my_pet_classifier.keras\"\n",
+ " shutil.copy(model_path, gcs_model_path)\n",
+ " model_path = gcs_model_path"
]
},
{
@@ -1055,7 +1194,7 @@
" except RuntimeError as e:\n",
" print(e)\n",
"\n",
- " model = tf.keras.models.load_model(model_dir)\n",
+ " model = tf.keras.models.load_model(model_path)\n",
"\n",
" def predict(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\n",
" inputs = {\n",
@@ -1102,16 +1241,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "24/10/03 17:39:09 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n",
- "[Stage 4:> (0 + 8) / 8]\r"
+ "25/01/06 21:15:01 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n",
+ "[Stage 5:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 24.1 ms, sys: 6.37 ms, total: 30.4 ms\n",
- "Wall time: 4.58 s\n"
+ "CPU times: user 17.1 ms, sys: 7.86 ms, total: 25 ms\n",
+ "Wall time: 4.81 s\n"
]
},
{
@@ -1138,15 +1277,15 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 5:> (0 + 8) / 8]\r"
+ "[Stage 6:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 98.2 ms, sys: 8.34 ms, total: 107 ms\n",
- "Wall time: 1.57 s\n"
+ "CPU times: user 79.2 ms, sys: 16.6 ms, total: 95.8 ms\n",
+ "Wall time: 1.59 s\n"
]
},
{
@@ -1173,14 +1312,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 6:> (0 + 8) / 8]\r"
+ "[Stage 7:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 17.6 ms, sys: 2.19 ms, total: 19.8 ms\n",
+ "CPU times: user 19.5 ms, sys: 5.43 ms, total: 25 ms\n",
"Wall time: 1.51 s\n"
]
},
@@ -1213,26 +1352,26 @@
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
"|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target| preds|\n",
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
- "| Cat| 1|Domestic Medium Hair|Female| White|No Color| Small| Medium| No| No| Healthy| 0| 2| 1| 1.9543833|\n",
- "| Dog| 2| Mixed Breed|Female| Black| Brown| Medium| Medium| No| No| Healthy| 0| 3| 1| 1.7454995|\n",
- "| Dog| 18| Dalmatian|Female| Black| White| Medium| Medium| Yes| No| Healthy|350| 5| 1| 1.5183508|\n",
- "| Dog| 3| Mixed Breed|Female| Black|No Color| Medium| Short| No| No| Healthy| 0| 1| 0| 0.67013955|\n",
- "| Dog| 2| Mixed Breed| Male| Black| Brown| Medium| Short| No| No| Healthy| 0| 1| 1| 1.5492269|\n",
- "| Dog| 36| Mixed Breed| Male| Brown|No Color| Medium| Medium| Not Sure| Yes|Minor Injury| 0| 1| 0|-0.27595556|\n",
- "| Cat| 6| Domestic Short Hair|Female| Black| White| Medium| Short| Yes| No| Healthy| 0| 5| 0| 0.7229555|\n",
- "| Dog| 72| Golden Retriever|Female| Cream|No Color| Large| Medium| Yes| Not Sure| Healthy| 0| 1| 1| 0.7397226|\n",
- "| Cat| 2| Domestic Short Hair| Male| Black| White| Small| Short| No| Yes| Healthy| 0| 4| 1| 1.4016482|\n",
- "| Dog| 3| Irish Terrier| Male| Brown| Cream| Medium| Medium| Yes| No|Minor Injury|200| 3| 0| 1.3436754|\n",
- "| Dog| 2| Mixed Breed|Female| White|No Color| Medium| Medium| Yes| No| Healthy| 0| 1| 1| 1.156439|\n",
- "| Dog| 2| Mixed Breed| Male| Brown|No Color| Medium| Medium| Yes| No| Healthy| 0| 4| 1| 1.7760799|\n",
- "| Cat| 2| Domestic Short Hair| Male| Black| Gray| Medium| Short| No| No| Healthy| 0| 2| 1| 1.8319463|\n",
- "| Dog| 1| German Shepherd Dog| Male| Gray|No Color| Medium| Medium| No| No| Healthy| 0| 6| 1| 2.5471144|\n",
- "| Dog| 24| Golden Retriever| Male|Yellow| Cream| Medium| Long| Yes| Yes| Healthy| 0| 7| 1| 1.4675076|\n",
- "| Dog| 1| Mixed Breed|Female| Black| Brown| Small| Medium| No| Yes| Healthy| 0| 1| 0| 0.8451028|\n",
- "| Cat| 12| Tuxedo| Male| Black|No Color| Small| Medium| Yes| Yes| Healthy| 50| 1| 1| 0.6487097|\n",
- "| Cat| 3| Domestic Short Hair|Female| Black|No Color| Small| Short| No| No| Healthy| 0| 1| 1| 1.0688435|\n",
- "| Dog| 2| Mixed Breed|Female| Brown| White| Medium| Short| No| No| Healthy| 0| 1| 1| 1.4086031|\n",
- "| Dog| 11| Mixed Breed|Female|Golden|No Color| Medium| Short| Yes| Yes| Healthy| 0| 9| 1| 0.28429908|\n",
+ "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure| Healthy| 0| 2| 0| 0.39443576|\n",
+ "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No| Healthy| 0| 4| 1| 0.7342956|\n",
+ "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure| Healthy| 0| 4| 1| 0.47965235|\n",
+ "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No| Healthy| 0| 3| 1| 1.0879391|\n",
+ "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| No| Healthy| 0| 4| 1| 1.0963857|\n",
+ "| Cat| 5|Domestic Medium Hair|Female| Gray|No Color| Medium| Medium| Yes| Not Sure| Healthy| 0| 1| 0| -0.2900008|\n",
+ "| Dog| 24| Beagle|Female| Black| Golden| Medium| Short| Not Sure| Not Sure|Minor Injury| 0| 1| 1| -0.4026273|\n",
+ "| Cat| 29| Tabby| Male| Brown| Golden| Medium| Short| No| No| Healthy| 0| 1| 0| 0.5199422|\n",
+ "| Dog| 9| Mixed Breed|Female| Black| Brown| Medium| Short| Yes| Yes| Healthy| 0| 2| 0|-0.13055041|\n",
+ "| Dog| 2| Mixed Breed|Female| Cream| White| Medium| Short| No| No| Healthy| 0| 1| 0| 1.5426368|\n",
+ "| Dog| 2| Mixed Breed| Male| Brown| White| Medium| Short| Yes| No| Healthy| 0| 1| 1| 1.3627914|\n",
+ "| Dog| 60| Golden Retriever| Male| Brown| Yellow| Medium| Medium| Yes| Yes| Healthy| 0| 5| 1| 0.76058775|\n",
+ "| Cat| 9| Siamese| Male| White|No Color| Medium| Short| Yes| No| Healthy| 0| 2| 1| 1.1733786|\n",
+ "| Dog| 19| Doberman Pinscher|Female| Black| Brown| Large| Short| Yes| Yes| Healthy|500| 2| 1| 0.72322303|\n",
+ "| Cat| 11| Domestic Short Hair| Male| Cream|No Color| Medium| Short| Yes| Yes| Healthy|100| 6| 0| 0.447542|\n",
+ "| Dog| 18| Mixed Breed|Female| Brown| White| Small| Short| Yes| No| Healthy| 0| 5| 0| 0.6399308|\n",
+ "| Dog| 4| Mixed Breed|Female| Brown| White| Medium| Medium| Not Sure| Not Sure| Healthy| 0| 3| 0|0.027931072|\n",
+ "| Dog| 96| Golden Retriever| Male|Golden|No Color| Large| Long| Yes| Yes| Healthy| 0| 2| 1| 0.79285514|\n",
+ "| Dog| 54| Golden Retriever| Male|Golden|No Color| Large| Medium| Yes| No| Healthy|350| 20| 1| 3.0215485|\n",
+ "| Cat| 5|Domestic Medium Hair|Female| Brown| White| Medium| Medium| No| No| Healthy| 0| 5| 1| 1.0509686|\n",
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
"only showing top 20 rows\n",
"\n"
@@ -1245,85 +1384,162 @@
},
{
"cell_type": "markdown",
- "id": "467b02a1-9f08-4fe8-a99c-581b7a01b8f6",
+ "id": "0c3e0390",
"metadata": {},
"source": [
- "### Using Triton Inference Server\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
+ "\n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
"\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
+ ""
]
},
{
- "cell_type": "markdown",
- "id": "22d1805b-7cac-4b27-9359-7a25b4ef3f71",
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef",
"metadata": {},
- "source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n tf-gpu -c conda-forge python=3.10.0\n",
- "conda activate tf-gpu\n",
- "\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 tensorflow[and-cuda] conda-pack\n",
- "\n",
- "conda-pack # tf-gpu.tar.gz\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
"outputs": [],
"source": [
- "import numpy as np\n",
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType"
+ "from functools import partial"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "4666e618-8038-4dc5-9be7-793aedbf4500",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "sudo: a terminal is required to read the password; either use the -S option to read from standard input or configure an askpass helper\n",
- "sudo: a password is required\n"
- ]
- }
- ],
+ "metadata": {},
+ "outputs": [],
"source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "sudo rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/feature_columns models\n",
+ "def triton_server(ports, model_path):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import tensorflow as tf\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
+ " # Enable GPU memory growth\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ "\n",
+ " model = tf.keras.models.load_model(model_path)\n",
+ "\n",
+ " def decode(input_tensor):\n",
+ " return tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(input_tensor))\n",
"\n",
- "# add custom execution environment\n",
- "cp tf-gpu.tar.gz models"
+ " def identity(input_tensor):\n",
+ " return tf.convert_to_tensor(input_tensor)\n",
+ "\n",
+ " input_transforms = {\n",
+ " \"Type\": decode,\n",
+ " \"Age\": identity,\n",
+ " \"Breed1\": decode,\n",
+ " \"Gender\": decode,\n",
+ " \"Color1\": decode,\n",
+ " \"Color2\": decode,\n",
+ " \"MaturitySize\": decode,\n",
+ " \"FurLength\": decode,\n",
+ " \"Vaccinated\": decode,\n",
+ " \"Sterilized\": decode,\n",
+ " \"Health\": decode,\n",
+ " \"Fee\": identity,\n",
+ " \"PhotoAmt\": identity\n",
+ " }\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " decoded_inputs = {k: input_transforms[k](v) for k, v in inputs.items()}\n",
+ " print(f\"SERVER: Received batch of size {len(decoded_inputs['Type'])}.\")\n",
+ " return {\n",
+ " \"preds\": model.predict(decoded_inputs)\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"PetClassifier\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"Type\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Age\", dtype=np.int64, shape=(-1,)),\n",
+ " Tensor(name=\"Breed1\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Gender\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Color1\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Color2\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"MaturitySize\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"FurLength\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Vaccinated\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Sterilized\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Health\", dtype=np.bytes_, shape=(-1,)),\n",
+ " Tensor(name=\"Fee\", dtype=np.int64, shape=(-1,)),\n",
+ " Tensor(name=\"PhotoAmt\", dtype=np.int64, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=128,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name, model_path):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports, model_path,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
]
},
{
"cell_type": "markdown",
- "id": "91bd1003-46c7-42d1-ab4d-869e52d62146",
+ "id": "617525a5",
"metadata": {},
"source": [
- "#### Start Triton Server on each executor"
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
]
},
{
@@ -1331,343 +1547,305 @@
"execution_count": 47,
"id": "a7fb146c-5319-4831-85f7-f2f3c084b042",
"metadata": {
- "scrolled": true,
- "tags": [
- "TRITON"
- ]
+ "scrolled": true
},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
- ]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 47,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "my_pet_classifier_dir = \"{}/my_pet_classifier.keras\".format(os.getcwd())\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
+ "def _use_stage_level_scheduling(spark, rdd):\n",
"\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"128M\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " my_pet_classifier_dir: {\"bind\": \"/my_pet_classifier.keras\", \"mode\": \"ro\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- "\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
- " \n",
- " return [True]\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
"\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ " return rdd.withResources(rp)"
]
},
{
"cell_type": "markdown",
- "id": "b75e6f20-f06c-4f4c-ada1-c562e078ed4b",
+ "id": "08095b39",
"metadata": {},
"source": [
- "#### Run inference"
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
]
},
{
"cell_type": "code",
"execution_count": 48,
- "id": "fe8dc3e6-f1b1-4a24-85f4-0a5ecabef4c5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "id": "1d8c358a",
+ "metadata": {},
"outputs": [],
"source": [
- "df = spark.read.parquet(\"datasets/petfinder-mini\")"
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
]
},
{
"cell_type": "code",
"execution_count": 49,
- "id": "ce92f041-930f-48ed-9a03-19f6c249ca27",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "id": "c9b98208",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "| Cat| 1|Domestic Medium Hair|Female| White|No Color| Small| Medium| No| No|Healthy| 0| 2| 1|\n",
- "| Dog| 2| Mixed Breed|Female| Black| Brown| Medium| Medium| No| No|Healthy| 0| 3| 1|\n",
- "| Dog| 18| Dalmatian|Female| Black| White| Medium| Medium| Yes| No|Healthy|350| 5| 1|\n",
- "| Dog| 3| Mixed Breed|Female| Black|No Color| Medium| Short| No| No|Healthy| 0| 1| 0|\n",
- "| Dog| 2| Mixed Breed| Male| Black| Brown| Medium| Short| No| No|Healthy| 0| 1| 1|\n",
- "+----+---+--------------------+------+------+--------+------------+---------+----------+----------+-------+---+--------+------+\n",
- "only showing top 5 rows\n",
- "\n"
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
]
}
],
"source": [
- "df.show(5)"
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ee5a2d8b",
+ "metadata": {},
+ "source": [
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
"execution_count": 50,
- "id": "4cfb3f34-a215-4781-91bf-2bec85e15633",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "id": "1006ba89",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "id": "918f14b8",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt', 'target']\n"
+ "Using ports [7000, 7001, 7002]\n"
]
}
],
"source": [
- "columns = df.columns\n",
- "print(columns)"
+ "model_name = \"PetClassifier\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
]
},
{
"cell_type": "code",
- "execution_count": 51,
- "id": "b315ee72-62af-476b-a994-0dba72d5f96e",
- "metadata": {
- "scrolled": true,
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 52,
+ "id": "dc4ff00f",
+ "metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 9:> (0 + 1) / 1]\r"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "['Type', 'Age', 'Breed1', 'Gender', 'Color1', 'Color2', 'MaturitySize', 'FurLength', 'Vaccinated', 'Sterilized', 'Health', 'Fee', 'PhotoAmt']\n"
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2806426\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
]
}
],
"source": [
- "# remove label column\n",
- "columns.remove(\"target\")\n",
- "print(columns)"
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name, model_path)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cb560288",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
]
},
{
"cell_type": "code",
- "execution_count": 52,
- "id": "da004eca-f7ad-4ee3-aa88-a6a20c1b72e5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 53,
+ "id": "3eec95bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "id": "e50b5fc8",
+ "metadata": {},
"outputs": [],
"source": [
- "def triton_fn(triton_uri, model_name):\n",
+ "def triton_fn(url, model_name):\n",
" import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
+ " from pytriton.client import ModelClient\n",
"\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\n",
- " # convert input ndarrays into a dictionary of ndarrays\n",
- " inputs = {\n",
- " \"Type\": t, \n",
- " \"Age\": a, \n",
- " \"Breed1\": b, \n",
- " \"Gender\": g,\n",
- " \"Color1\": c1,\n",
- " \"Color2\": c2,\n",
- " \"MaturitySize\": m,\n",
- " \"FurLength\": f,\n",
- " \"Vaccinated\": v, \n",
- " \"Sterilized\": s,\n",
- " \"Health\": h,\n",
- " \"Fee\": fee,\n",
- " \"PhotoAmt\": p\n",
- " }\n",
- " return _predict(inputs)\n",
- " \n",
- " def _predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
+ " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n",
+ "\n",
+ " def infer_batch(t, a, b, g, c1, c2, m, f, v, s, h, fee, p):\n",
" \n",
- " return predict"
+ " def encode(value):\n",
+ " return np.vectorize(lambda x: x.encode(\"utf-8\"))(value).astype(np.bytes_)\n",
+ "\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " encoded_inputs = {\n",
+ " \"Type\": encode(t), \n",
+ " \"Age\": a, \n",
+ " \"Breed1\": encode(b), \n",
+ " \"Gender\": encode(g),\n",
+ " \"Color1\": encode(c1),\n",
+ " \"Color2\": encode(c2),\n",
+ " \"MaturitySize\": encode(m),\n",
+ " \"FurLength\": encode(f),\n",
+ " \"Vaccinated\": encode(v),\n",
+ " \"Sterilized\": encode(s),\n",
+ " \"Health\": encode(h),\n",
+ " \"Fee\": fee,\n",
+ " \"PhotoAmt\": p\n",
+ " }\n",
+ " result_data = client.infer_batch(**encoded_inputs)\n",
+ " return result_data[\"preds\"]\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2edd887f",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 53,
- "id": "2ffb020e-dc93-456b-bee6-405611eee1e1",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 55,
+ "id": "fe8dc3e6-f1b1-4a24-85f4-0a5ecabef4c5",
+ "metadata": {},
"outputs": [],
"source": [
- "from functools import partial\n",
- "\n",
- "# need to pass the list of columns into the model_udf\n",
- "classify = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"feature_columns\"),\n",
- " input_tensor_shapes=[[1]] * len(columns),\n",
- " return_type=FloatType(),\n",
- " batch_size=1024)"
+ "df = spark.read.parquet(data_path)"
]
},
{
"cell_type": "code",
- "execution_count": 54,
- "id": "7657f820-5ec2-4ac8-a107-4b58773d204a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "+----+---+----------+------+------+--------+------------+---------+----------+----------+----------+---+--------+------+----------+\n",
- "|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target| preds|\n",
- "+----+---+----------+------+------+--------+------------+---------+----------+----------+----------+---+--------+------+----------+\n",
- "| Cat| 1|Domesti...|Female| White|No Color| Small| Medium| No| No| Healthy| 0| 2| 1| 1.9543833|\n",
- "| Dog| 2|Mixed B...|Female| Black| Brown| Medium| Medium| No| No| Healthy| 0| 3| 1| 1.7454995|\n",
- "| Dog| 18| Dalmatian|Female| Black| White| Medium| Medium| Yes| No| Healthy|350| 5| 1| 1.5183508|\n",
- "| Dog| 3|Mixed B...|Female| Black|No Color| Medium| Short| No| No| Healthy| 0| 1| 0|0.67013955|\n",
- "| Dog| 2|Mixed B...| Male| Black| Brown| Medium| Short| No| No| Healthy| 0| 1| 1| 1.5492269|\n",
- "| Dog| 36|Mixed B...| Male| Brown|No Color| Medium| Medium| Not Sure| Yes|Minor I...| 0| 1| 0|-0.2759...|\n",
- "| Cat| 6|Domesti...|Female| Black| White| Medium| Short| Yes| No| Healthy| 0| 5| 0| 0.7229555|\n",
- "| Dog| 72|Golden ...|Female| Cream|No Color| Large| Medium| Yes| Not Sure| Healthy| 0| 1| 1| 0.7397226|\n",
- "| Cat| 2|Domesti...| Male| Black| White| Small| Short| No| Yes| Healthy| 0| 4| 1| 1.4016482|\n",
- "| Dog| 3|Irish T...| Male| Brown| Cream| Medium| Medium| Yes| No|Minor I...|200| 3| 0| 1.3436754|\n",
- "| Dog| 2|Mixed B...|Female| White|No Color| Medium| Medium| Yes| No| Healthy| 0| 1| 1| 1.156439|\n",
- "| Dog| 2|Mixed B...| Male| Brown|No Color| Medium| Medium| Yes| No| Healthy| 0| 4| 1| 1.7760799|\n",
- "| Cat| 2|Domesti...| Male| Black| Gray| Medium| Short| No| No| Healthy| 0| 2| 1| 1.8319463|\n",
- "| Dog| 1|German ...| Male| Gray|No Color| Medium| Medium| No| No| Healthy| 0| 6| 1| 2.5471144|\n",
- "| Dog| 24|Golden ...| Male|Yellow| Cream| Medium| Long| Yes| Yes| Healthy| 0| 7| 1| 1.4675076|\n",
- "| Dog| 1|Mixed B...|Female| Black| Brown| Small| Medium| No| Yes| Healthy| 0| 1| 0| 0.8451028|\n",
- "| Cat| 12| Tuxedo| Male| Black|No Color| Small| Medium| Yes| Yes| Healthy| 50| 1| 1| 0.6487097|\n",
- "| Cat| 3|Domesti...|Female| Black|No Color| Small| Short| No| No| Healthy| 0| 1| 1| 1.0688435|\n",
- "| Dog| 2|Mixed B...|Female| Brown| White| Medium| Short| No| No| Healthy| 0| 1| 1| 1.4086031|\n",
- "| Dog| 11|Mixed B...|Female|Golden|No Color| Medium| Short| Yes| Yes| Healthy| 0| 9| 1|0.28429908|\n",
- "+----+---+----------+------+------+--------+------------+---------+----------+----------+----------+---+--------+------+----------+\n",
- "only showing top 20 rows\n",
- "\n"
- ]
- }
- ],
+ "execution_count": 56,
+ "id": "4cfb3f34-a215-4781-91bf-2bec85e15633",
+ "metadata": {},
+ "outputs": [],
"source": [
- "# WITHOUT custom python backend, FAILS with: Op type not registered 'DenseBincount' \n",
- "df.withColumn(\"preds\", classify(struct(*columns))).show(truncate=10)"
+ "columns = df.columns\n",
+ "# remove label column\n",
+ "columns.remove(\"target\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b75e6f20-f06c-4f4c-ada1-c562e078ed4b",
+ "metadata": {},
+ "source": [
+ "#### Run inference"
]
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 57,
+ "id": "2ffb020e-dc93-456b-bee6-405611eee1e1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# need to pass the list of columns into the model_udf\n",
+ "classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
+ " input_tensor_shapes=[[1]] * len(columns),\n",
+ " return_type=FloatType(),\n",
+ " batch_size=64)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
"id": "e6ff0356-becd-421f-aebb-272497d5ad6a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 11:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 17.2 ms, sys: 2.85 ms, total: 20 ms\n",
- "Wall time: 2.5 s\n"
+ "CPU times: user 22.5 ms, sys: 5.09 ms, total: 27.6 ms\n",
+ "Wall time: 8.46 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
]
}
],
@@ -1679,34 +1857,23 @@
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 59,
"id": "ce18ee7c-5958-4986-b200-6d986fcc6243",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "[Stage 13:==================================================> (7 + 1) / 8]\r"
+ " \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 23.4 ms, sys: 984 μs, total: 24.4 ms\n",
- "Wall time: 2.5 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 14.9 ms, sys: 4.68 ms, total: 19.6 ms\n",
+ "Wall time: 5.81 s\n"
]
}
],
@@ -1718,27 +1885,30 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": 60,
"id": "0888ce40-b2c4-4aed-8ccb-6a8bcd00abc8",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 13:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 14.7 ms, sys: 4.61 ms, total: 19.3 ms\n",
- "Wall time: 2.47 s\n"
+ "CPU times: user 87.8 ms, sys: 4.62 ms, total: 92.5 ms\n",
+ "Wall time: 5.96 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
]
}
],
@@ -1750,13 +1920,9 @@
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": 61,
"id": "d45812b5-f584-41a4-a821-2b59e065671c",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
@@ -1765,26 +1931,26 @@
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
"|Type|Age| Breed1|Gender|Color1| Color2|MaturitySize|FurLength|Vaccinated|Sterilized| Health|Fee|PhotoAmt|target| preds|\n",
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
- "| Cat| 1|Domestic Medium Hair|Female| White|No Color| Small| Medium| No| No| Healthy| 0| 2| 1| 1.9543833|\n",
- "| Dog| 2| Mixed Breed|Female| Black| Brown| Medium| Medium| No| No| Healthy| 0| 3| 1| 1.7454995|\n",
- "| Dog| 18| Dalmatian|Female| Black| White| Medium| Medium| Yes| No| Healthy|350| 5| 1| 1.5183508|\n",
- "| Dog| 3| Mixed Breed|Female| Black|No Color| Medium| Short| No| No| Healthy| 0| 1| 0| 0.67013955|\n",
- "| Dog| 2| Mixed Breed| Male| Black| Brown| Medium| Short| No| No| Healthy| 0| 1| 1| 1.5492269|\n",
- "| Dog| 36| Mixed Breed| Male| Brown|No Color| Medium| Medium| Not Sure| Yes|Minor Injury| 0| 1| 0|-0.27595556|\n",
- "| Cat| 6| Domestic Short Hair|Female| Black| White| Medium| Short| Yes| No| Healthy| 0| 5| 0| 0.7229555|\n",
- "| Dog| 72| Golden Retriever|Female| Cream|No Color| Large| Medium| Yes| Not Sure| Healthy| 0| 1| 1| 0.7397226|\n",
- "| Cat| 2| Domestic Short Hair| Male| Black| White| Small| Short| No| Yes| Healthy| 0| 4| 1| 1.4016482|\n",
- "| Dog| 3| Irish Terrier| Male| Brown| Cream| Medium| Medium| Yes| No|Minor Injury|200| 3| 0| 1.3436754|\n",
- "| Dog| 2| Mixed Breed|Female| White|No Color| Medium| Medium| Yes| No| Healthy| 0| 1| 1| 1.156439|\n",
- "| Dog| 2| Mixed Breed| Male| Brown|No Color| Medium| Medium| Yes| No| Healthy| 0| 4| 1| 1.7760799|\n",
- "| Cat| 2| Domestic Short Hair| Male| Black| Gray| Medium| Short| No| No| Healthy| 0| 2| 1| 1.8319463|\n",
- "| Dog| 1| German Shepherd Dog| Male| Gray|No Color| Medium| Medium| No| No| Healthy| 0| 6| 1| 2.5471144|\n",
- "| Dog| 24| Golden Retriever| Male|Yellow| Cream| Medium| Long| Yes| Yes| Healthy| 0| 7| 1| 1.4675076|\n",
- "| Dog| 1| Mixed Breed|Female| Black| Brown| Small| Medium| No| Yes| Healthy| 0| 1| 0| 0.8451028|\n",
- "| Cat| 12| Tuxedo| Male| Black|No Color| Small| Medium| Yes| Yes| Healthy| 50| 1| 1| 0.6487097|\n",
- "| Cat| 3| Domestic Short Hair|Female| Black|No Color| Small| Short| No| No| Healthy| 0| 1| 1| 1.0688435|\n",
- "| Dog| 2| Mixed Breed|Female| Brown| White| Medium| Short| No| No| Healthy| 0| 1| 1| 1.4086031|\n",
- "| Dog| 11| Mixed Breed|Female|Golden|No Color| Medium| Short| Yes| Yes| Healthy| 0| 9| 1| 0.28429908|\n",
+ "| Dog| 3| Mixed Breed| Male| Black|No Color| Small| Medium| Not Sure| Not Sure| Healthy| 0| 2| 0| 0.39443576|\n",
+ "| Dog| 9| Mixed Breed| Male| Gray|No Color| Medium| Short| Not Sure| No| Healthy| 0| 4| 1| 0.7342956|\n",
+ "| Cat| 4| Domestic Short Hair| Male| Black| Gray| Medium| Short| Not Sure| Not Sure| Healthy| 0| 4| 1| 0.47965235|\n",
+ "| Cat| 6| Domestic Short Hair| Male|Yellow| White| Medium| Short| No| No| Healthy| 0| 3| 1| 1.0879391|\n",
+ "| Cat| 6|Domestic Medium Hair| Male| Gray|No Color| Small| Medium| No| No| Healthy| 0| 4| 1| 1.0963857|\n",
+ "| Cat| 5|Domestic Medium Hair|Female| Gray|No Color| Medium| Medium| Yes| Not Sure| Healthy| 0| 1| 0| -0.2900008|\n",
+ "| Dog| 24| Beagle|Female| Black| Golden| Medium| Short| Not Sure| Not Sure|Minor Injury| 0| 1| 1| -0.4026273|\n",
+ "| Cat| 29| Tabby| Male| Brown| Golden| Medium| Short| No| No| Healthy| 0| 1| 0| 0.5199422|\n",
+ "| Dog| 9| Mixed Breed|Female| Black| Brown| Medium| Short| Yes| Yes| Healthy| 0| 2| 0|-0.13055041|\n",
+ "| Dog| 2| Mixed Breed|Female| Cream| White| Medium| Short| No| No| Healthy| 0| 1| 0| 1.5426368|\n",
+ "| Dog| 2| Mixed Breed| Male| Brown| White| Medium| Short| Yes| No| Healthy| 0| 1| 1| 1.3627914|\n",
+ "| Dog| 60| Golden Retriever| Male| Brown| Yellow| Medium| Medium| Yes| Yes| Healthy| 0| 5| 1| 0.76058775|\n",
+ "| Cat| 9| Siamese| Male| White|No Color| Medium| Short| Yes| No| Healthy| 0| 2| 1| 1.1733786|\n",
+ "| Dog| 19| Doberman Pinscher|Female| Black| Brown| Large| Short| Yes| Yes| Healthy|500| 2| 1| 0.72322303|\n",
+ "| Cat| 11| Domestic Short Hair| Male| Cream|No Color| Medium| Short| Yes| Yes| Healthy|100| 6| 0| 0.447542|\n",
+ "| Dog| 18| Mixed Breed|Female| Brown| White| Small| Short| Yes| No| Healthy| 0| 5| 0| 0.6399308|\n",
+ "| Dog| 4| Mixed Breed|Female| Brown| White| Medium| Medium| Not Sure| Not Sure| Healthy| 0| 3| 0|0.027931072|\n",
+ "| Dog| 96| Golden Retriever| Male|Golden|No Color| Large| Long| Yes| Yes| Healthy| 0| 2| 1| 0.79285514|\n",
+ "| Dog| 54| Golden Retriever| Male|Golden|No Color| Large| Medium| Yes| No| Healthy|350| 20| 1| 3.0215485|\n",
+ "| Cat| 5|Domestic Medium Hair|Female| Brown| White| Medium| Medium| No| No| Healthy| 0| 5| 1| 1.0509686|\n",
"+----+---+--------------------+------+------+--------+------------+---------+----------+----------+------------+---+--------+------+-----------+\n",
"only showing top 20 rows\n",
"\n"
@@ -1807,14 +1973,17 @@
},
{
"cell_type": "code",
- "execution_count": 59,
+ "execution_count": 62,
"id": "6914f44f-677f-4db3-be09-783df8d11b8a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -1828,31 +1997,39 @@
"[True]"
]
},
- "execution_count": 59,
+ "execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 60,
+ "execution_count": null,
"id": "f8c6ee43-8891-4446-986e-1447c5d48bac",
"metadata": {},
"outputs": [],
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_resnet50_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_resnet50_tf.ipynb
new file mode 100644
index 000000000..8933dd6c0
--- /dev/null
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/keras_resnet50_tf.ipynb
@@ -0,0 +1,1507 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "8e6810cc-5982-4293-bfbd-c91ef0aca204",
+ "metadata": {},
+ "source": [
+ "\n",
+ "\n",
+ "# PySpark Tensorflow Inference\n",
+ "\n",
+ "### Flower Recognition with Keras Resnet50\n",
+ "\n",
+ "In this notebook, we demonstrate distribute inference with Resnet50 on the Databricks flower photos dataset. \n",
+ "From: https://docs.databricks.com/_static/notebooks/deep-learning/keras-metadata.html"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "858e3a8d",
+ "metadata": {},
+ "source": [
+ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "cf329ac8-0763-44bc-b0f6-b634b7dc480e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-01-10 22:19:53.462313: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-01-10 22:19:53.469394: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-01-10 22:19:53.477238: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2025-01-10 22:19:53.479534: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2025-01-10 22:19:53.485978: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2025-01-10 22:19:53.848582: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ ]
+ }
+ ],
+ "source": [
+ "import os\n",
+ "import shutil\n",
+ "import subprocess\n",
+ "import time\n",
+ "import json\n",
+ "import pandas as pd\n",
+ "from PIL import Image\n",
+ "import numpy as np\n",
+ "import uuid\n",
+ " \n",
+ "import tensorflow as tf\n",
+ "from tensorflow.keras.applications.resnet50 import ResNet50"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "532d562d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdir('models') if not os.path.exists('models') else None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "75175140",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2.17.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(tf.__version__)\n",
+ "\n",
+ "# Enable GPU memory growth\n",
+ "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ "if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "02fe61b8",
+ "metadata": {},
+ "source": [
+ "## PySpark"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b474339c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from pyspark.sql.functions import col, struct, pandas_udf, PandasUDFType\n",
+ "from pyspark.ml.functions import predict_batch_udf\n",
+ "from pyspark.sql.types import *\n",
+ "from pyspark.sql import SparkSession\n",
+ "from pyspark import SparkConf\n",
+ "from typing import Iterator, Tuple"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e182cacb",
+ "metadata": {},
+ "source": [
+ "Check the cluster environment to handle any platform-specific Spark configurations."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "564b1d33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
+ "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
+ "on_standalone = not (on_databricks or on_dataproc)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "016cdd0b",
+ "metadata": {},
+ "source": [
+ "#### Create Spark Session\n",
+ "\n",
+ "For local standalone clusters, we'll connect to the cluster and create the Spark Session. \n",
+ "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "44d72768",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/10 22:19:55 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
+ "25/01/10 22:19:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
+ "Setting default log level to \"WARN\".\n",
+ "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
+ "25/01/10 22:19:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
+ ]
+ }
+ ],
+ "source": [
+ "conf = SparkConf()\n",
+ "\n",
+ "if 'spark' not in globals():\n",
+ " if on_standalone:\n",
+ " import socket\n",
+ " \n",
+ " conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
+ " hostname = socket.gethostname()\n",
+ " conf.setMaster(f\"spark://{hostname}:7077\")\n",
+ " conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
+ " conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_env}/lib:{conda_env}/lib/python3.11/site-packages/nvidia_pytriton.libs:$LD_LIBRARY_PATH\")\n",
+ " source = \"/usr/lib/x86_64-linux-gnu/libstdc++.so.6\"\n",
+ " target = f\"{conda_env}/lib/libstdc++.so.6\"\n",
+ " try:\n",
+ " if os.path.islink(target) or os.path.exists(target):\n",
+ " os.remove(target)\n",
+ " os.symlink(source, target)\n",
+ " except OSError as e:\n",
+ " print(f\"Error creating symlink: {e}\")\n",
+ " elif on_dataproc:\n",
+ " # Point PyTriton to correct libpython3.11.so:\n",
+ " conda_lib_path=\"/opt/conda/miniconda3/lib\"\n",
+ " conf.set(\"spark.executorEnv.LD_LIBRARY_PATH\", f\"{conda_lib_path}:$LD_LIBRARY_PATH\") \n",
+ " conf.set(\"spark.executorEnv.TF_GPU_ALLOCATOR\", \"cuda_malloc_async\")\n",
+ " conf.set(\"spark.executor.instances\", \"4\") # dataproc defaults to 2\n",
+ "\n",
+ " conf.set(\"spark.executor.cores\", \"8\")\n",
+ " conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
+ " conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
+ " conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
+ " conf.set(\"spark.python.worker.reuse\", \"true\")\n",
+ " conf.set(\"spark.driver.memory\", \"8g\")\n",
+ " conf.set(\"spark.executor.memory\", \"8g\")\n",
+ "\n",
+ "conf.set(\"spark.sql.execution.arrow.maxRecordsPerBatch\", \"512\")\n",
+ "conf.set(\"spark.sql.parquet.columnarReaderBatchSize\", \"1024\")\n",
+ "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
+ "sc = spark.sparkContext"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "61c406fa",
+ "metadata": {},
+ "source": [
+ "Define the input and output directories."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c566dc17",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "os.mkdirs(\"spark-dl-datasets\") if not os.path.exists(\"spark-dl-datasets\") else None\n",
+ "data_path = \"spark-dl-datasets/flowers_{uuid}.parquet\".format(uuid=str(uuid.uuid1()))\n",
+ "local_file_path = f\"{os.getcwd()}/{data_path}\"\n",
+ "output_file_path = \"predictions/predictions\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "968d08a7-66b9-444f-b362-d8df692aef1c",
+ "metadata": {},
+ "source": [
+ "### Prepare trained model and data for inference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "da083168-137f-492c-8769-d8f1e2111756",
+ "metadata": {},
+ "source": [
+ "Load the ResNet-50 Model and broadcast the weights."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "2ddc715a-cdbc-4c49-93e9-58c9d88511da",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2025-01-10 22:19:56.188374: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46469 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = ResNet50()\n",
+ "bc_model_weights = sc.broadcast(model.get_weights())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "77dddfa3-e8df-4e8e-8251-64457f1ebf80",
+ "metadata": {},
+ "source": [
+ "Load the data and save the datasets to one Parquet file."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "c0738bec-97d4-4946-8c49-5e6d07ff1afc",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Image count: 3670\n"
+ ]
+ }
+ ],
+ "source": [
+ "import pathlib\n",
+ "dataset_url = \"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\"\n",
+ "data_dir = tf.keras.utils.get_file(origin=dataset_url,\n",
+ " fname='flower_photos',\n",
+ " untar=True)\n",
+ "data_dir = pathlib.Path(data_dir)\n",
+ "image_count = len(list(data_dir.glob('*/*.jpg')))\n",
+ "print(f\"Image count: {image_count}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d54f470a-d308-4426-8ed0-33f95155bb4f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(data_dir) for f in filenames if os.path.splitext(f)[1] == '.jpg']\n",
+ "files = files[:2048]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "64f94ee0-f1ea-47f6-a77e-be8da5d1b87a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "image_data = []\n",
+ "for file in files:\n",
+ " img = Image.open(file)\n",
+ " img = img.resize([224, 224])\n",
+ " data = np.asarray(img, dtype=\"float32\").reshape([224*224*3])\n",
+ "\n",
+ " image_data.append({\"data\": data})\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "b4ae1a98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pd.DataFrame(image_data, columns=['data']).to_parquet(data_path)\n",
+ "\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
+ " shutil.copy(local_file_path, \"/dbfs/FileStore/{}\".format(data_path))\n",
+ " data_path = \"/dbfs/FileStore/{}\".format(data_path)\n",
+ "elif on_dataproc:\n",
+ " data_dir = \"/mnt/gcs/spark-dl/spark-dl-datasets\"\n",
+ " os.mkdir(data_dir) if not os.path.exists(data_dir) else None\n",
+ " shutil.copy(local_file_path, \"/mnt/gcs/spark-dl/\" + data_path)\n",
+ " data_path = \"file:///mnt/gcs/spark-dl/\" + data_path"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f2414b0f-58f2-4e4a-9d09-8ea95b38d413",
+ "metadata": {},
+ "source": [
+ "### Save Model\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "670328e3-7274-4d78-b315-487750166a3f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_path = 'models/resnet50_model.keras'\n",
+ "model.save(model_path)\n",
+ "\n",
+ "# For cloud environments, copy the model to the distributed file system.\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+ " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/resnet50_model.keras\"\n",
+ " shutil.copy(model_path, dbfs_model_path)\n",
+ " model_path = dbfs_model_path\n",
+ "elif on_dataproc:\n",
+ " # GCS is mounted at /mnt/gcs by the init script\n",
+ " models_dir = \"/mnt/gcs/spark-dl/models\"\n",
+ " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+ " gcs_model_path = models_dir + \"/resnet50_model.keras\"\n",
+ " shutil.copy(model_path, gcs_model_path)\n",
+ " model_path = gcs_model_path"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b827ad56-1af0-41b7-be68-94bd203a2a70",
+ "metadata": {},
+ "source": [
+ "### Load the data into Spark DataFrames"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "8ddc22d0-b88a-4906-bd47-bf247e34feeb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "2048\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = spark.read.parquet(data_path)\n",
+ "print(df.count())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "865929b0-b016-4de4-996d-7f16176cf49c",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "### Model inference via Pandas UDF"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b1f5a747",
+ "metadata": {},
+ "source": [
+ "Define the function to parse the input data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "a67b3128-13c1-44f1-a0c0-7cf7a836fee3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def parse_image(image_data):\n",
+ " image = tf.image.convert_image_dtype(\n",
+ " image_data, dtype=tf.float32) * (2. / 255) - 1\n",
+ " image = tf.reshape(image, [224, 224, 3])\n",
+ " return image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "024e4ba2",
+ "metadata": {},
+ "source": [
+ "Define the function for model inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "7b33185f-6d1e-4ca9-9757-fdc3d736496b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@pandas_udf(ArrayType(FloatType()))\n",
+ "def pandas_predict_udf(iter: Iterator[Tuple[pd.Series]]) -> Iterator[pd.Series]:\n",
+ "\n",
+ " # Enable GPU memory growth to avoid CUDA OOM\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ "\n",
+ " batch_size = 64\n",
+ " model = ResNet50(weights=None)\n",
+ " model.set_weights(bc_model_weights.value)\n",
+ " for image_batch in iter:\n",
+ " images = np.vstack(image_batch)\n",
+ " dataset = tf.data.Dataset.from_tensor_slices(images)\n",
+ " dataset = dataset.map(parse_image, num_parallel_calls=8).prefetch(\n",
+ " 5000).batch(batch_size)\n",
+ " preds = model.predict(dataset)\n",
+ " yield pd.Series(list(preds))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "08190547",
+ "metadata": {},
+ "source": [
+ "Run model inference and save the results to Parquet."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "ad8c05da-db38-45ef-81d0-1f862f575ced",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 4:===================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 39.5 ms, sys: 28.8 ms, total: 68.3 ms\n",
+ "Wall time: 15.3 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions_1 = df.select(pandas_predict_udf(col(\"data\")).alias(\"prediction\"))\n",
+ "results = predictions_1.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "08cb2a10",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 6:============================================> (3 + 1) / 4]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| prediction|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|[1.2938889E-4, 2.4666305E-4, 6.765791E-5, 1.2263245E-4, 5.7486624E-5, 3.9616702E-4, 7.0566134E-6,...|\n",
+ "|[4.4501914E-5, 3.5403698E-4, 4.6702033E-5, 8.102543E-5, 3.1704556E-5, 1.9194305E-4, 7.905952E-6, ...|\n",
+ "|[1.05672516E-4, 2.2686279E-4, 3.0055395E-5, 6.523785E-5, 2.352077E-5, 3.7122983E-4, 3.3315896E-6,...|\n",
+ "|[2.0331638E-5, 2.2746396E-4, 7.828012E-5, 6.986782E-5, 4.705316E-5, 9.80732E-5, 5.561918E-6, 2.35...|\n",
+ "|[1.130241E-4, 2.3187004E-4, 5.296914E-5, 1.0871329E-4, 4.027478E-5, 3.7183522E-4, 5.5931855E-6, 3...|\n",
+ "|[9.094467E-5, 2.06384E-4, 4.514821E-5, 7.665891E-5, 3.2262324E-5, 3.3875552E-4, 3.831814E-6, 4.18...|\n",
+ "|[1.07847634E-4, 3.7848807E-4, 7.660533E-5, 1.2446754E-4, 4.7595917E-5, 3.333814E-4, 1.0669675E-5,...|\n",
+ "|[2.2261223E-5, 2.734666E-4, 3.8122747E-5, 6.2266954E-5, 1.7935155E-5, 1.7268128E-4, 6.034271E-6, ...|\n",
+ "|[1.1065645E-4, 2.900581E-4, 4.2585547E-5, 1.074203E-4, 3.052314E-5, 4.794604E-4, 6.4872897E-6, 3....|\n",
+ "|[9.673917E-5, 2.058331E-4, 7.4652424E-5, 1.1323769E-4, 4.6106186E-5, 2.8604185E-4, 5.62365E-6, 5....|\n",
+ "|[7.411196E-5, 3.291524E-4, 1.3454164E-4, 1.7738447E-4, 8.467504E-5, 2.2466244E-4, 1.3621126E-5, 1...|\n",
+ "|[8.721524E-5, 2.7338538E-4, 3.5964815E-5, 7.792533E-5, 2.3559302E-5, 3.6789547E-4, 3.5665628E-6, ...|\n",
+ "|[9.723709E-5, 2.7619812E-4, 5.7464153E-5, 1.10104906E-4, 3.8317143E-5, 3.490506E-4, 6.1553183E-6,...|\n",
+ "|[6.940235E-5, 2.5377885E-4, 5.057188E-5, 1.1485363E-4, 3.0059196E-5, 2.7862669E-4, 5.024019E-6, 5...|\n",
+ "|[4.2095784E-5, 2.4891715E-4, 1.236292E-4, 1.4306813E-4, 7.3354306E-5, 1.6047148E-4, 7.958807E-6, ...|\n",
+ "|[2.7327887E-5, 3.8553146E-4, 1.2939748E-4, 1.5762268E-4, 7.307493E-5, 8.5530424E-5, 1.2648808E-5,...|\n",
+ "|[3.036101E-5, 3.5572305E-4, 1.600718E-4, 2.1437313E-4, 8.063033E-5, 1.02061334E-4, 1.3876456E-5, ...|\n",
+ "|[3.3109587E-5, 2.8182982E-4, 1.7998899E-4, 2.0246049E-4, 1.3720036E-4, 1.01000114E-4, 3.427488E-5...|\n",
+ "|[4.549448E-5, 2.8782588E-4, 2.3703449E-4, 2.448979E-4, 1.20997625E-4, 1.3744453E-4, 1.62803E-5, 2...|\n",
+ "|[1.2242574E-4, 2.8095162E-4, 6.332559E-5, 1.0209269E-4, 4.335324E-5, 3.906304E-4, 8.205706E-6, 6....|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_1.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "40799f8e-443e-40ca-919b-391f901cb3f4",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_1.write.mode(\"overwrite\").parquet(output_file_path + \"_1\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e7a69aa9",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "dda88b46-6300-4bf7-bc10-7403f4fbbf92",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_batch_fn():\n",
+ " import tensorflow as tf\n",
+ " from tensorflow.keras.applications.resnet50 import ResNet50\n",
+ "\n",
+ " # Enable GPU memory growth\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ "\n",
+ " model = ResNet50()\n",
+ " def predict(inputs):\n",
+ " inputs = inputs * (2. / 255) - 1\n",
+ " return model.predict(inputs)\n",
+ " return predict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "cff0e851-563d-40b6-9d05-509c22b3b7f9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "classify = predict_batch_udf(predict_batch_fn,\n",
+ " input_tensor_shapes=[[224, 224, 3]],\n",
+ " return_type=ArrayType(FloatType()),\n",
+ " batch_size=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "aa7c156f-e2b3-4837-9427-ccf3a5720412",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.read.parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "80bc50ad-eaf5-4fce-a354-5e17d65e2da5",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 9:===================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 68.6 ms, sys: 27.9 ms, total: 96.5 ms\n",
+ "Wall time: 16.1 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# first pass caches model/fn\n",
+ "predictions_2 = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n",
+ "results = predictions_2.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "41cace80-7a4b-4929-8e63-9c83f9745e02",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 10:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 134 ms, sys: 34.6 ms, total: 168 ms\n",
+ "Wall time: 14.9 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions_2 = df.select(classify(\"data\").alias(\"prediction\"))\n",
+ "results = predictions_2.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "56a2ec8a-de09-4d7c-9666-1b3c76f10657",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 11:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 55.3 ms, sys: 21.7 ms, total: 77 ms\n",
+ "Wall time: 9.57 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions_2 = df.select(classify(col(\"data\")).alias(\"prediction\"))\n",
+ "results = predictions_2.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "2dcf3791",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 13:===========================================> (3 + 1) / 4]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| prediction|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|[1.2938889E-4, 2.4666305E-4, 6.765791E-5, 1.2263245E-4, 5.7486624E-5, 3.9616702E-4, 7.0566134E-6,...|\n",
+ "|[4.4501914E-5, 3.5403698E-4, 4.6702033E-5, 8.102543E-5, 3.1704556E-5, 1.9194305E-4, 7.905952E-6, ...|\n",
+ "|[1.05672516E-4, 2.2686279E-4, 3.0055395E-5, 6.523785E-5, 2.352077E-5, 3.7122983E-4, 3.3315896E-6,...|\n",
+ "|[2.0331638E-5, 2.2746396E-4, 7.828012E-5, 6.986782E-5, 4.705316E-5, 9.80732E-5, 5.561918E-6, 2.35...|\n",
+ "|[1.130241E-4, 2.3187004E-4, 5.296914E-5, 1.0871329E-4, 4.027478E-5, 3.7183522E-4, 5.5931855E-6, 3...|\n",
+ "|[9.094467E-5, 2.06384E-4, 4.514821E-5, 7.665891E-5, 3.2262324E-5, 3.3875552E-4, 3.831814E-6, 4.18...|\n",
+ "|[1.07847634E-4, 3.7848807E-4, 7.660533E-5, 1.2446754E-4, 4.7595917E-5, 3.333814E-4, 1.0669675E-5,...|\n",
+ "|[2.2261223E-5, 2.734666E-4, 3.8122747E-5, 6.2266954E-5, 1.7935155E-5, 1.7268128E-4, 6.034271E-6, ...|\n",
+ "|[1.1065645E-4, 2.900581E-4, 4.2585547E-5, 1.074203E-4, 3.052314E-5, 4.794604E-4, 6.4872897E-6, 3....|\n",
+ "|[9.673917E-5, 2.058331E-4, 7.4652424E-5, 1.1323769E-4, 4.6106186E-5, 2.8604185E-4, 5.62365E-6, 5....|\n",
+ "|[7.411196E-5, 3.291524E-4, 1.3454164E-4, 1.7738447E-4, 8.467504E-5, 2.2466244E-4, 1.3621126E-5, 1...|\n",
+ "|[8.721524E-5, 2.7338538E-4, 3.5964815E-5, 7.792533E-5, 2.3559302E-5, 3.6789547E-4, 3.5665628E-6, ...|\n",
+ "|[9.723709E-5, 2.7619812E-4, 5.7464153E-5, 1.10104906E-4, 3.8317143E-5, 3.490506E-4, 6.1553183E-6,...|\n",
+ "|[6.940235E-5, 2.5377885E-4, 5.057188E-5, 1.1485363E-4, 3.0059196E-5, 2.7862669E-4, 5.024019E-6, 5...|\n",
+ "|[4.2095784E-5, 2.4891715E-4, 1.236292E-4, 1.4306813E-4, 7.3354306E-5, 1.6047148E-4, 7.958807E-6, ...|\n",
+ "|[2.7327887E-5, 3.8553146E-4, 1.2939748E-4, 1.5762268E-4, 7.307493E-5, 8.5530424E-5, 1.2648808E-5,...|\n",
+ "|[3.036101E-5, 3.5572305E-4, 1.600718E-4, 2.1437313E-4, 8.063033E-5, 1.02061334E-4, 1.3876456E-5, ...|\n",
+ "|[3.3109587E-5, 2.8182982E-4, 1.7998899E-4, 2.0246049E-4, 1.3720036E-4, 1.01000114E-4, 3.427488E-5...|\n",
+ "|[4.549448E-5, 2.8782588E-4, 2.3703449E-4, 2.448979E-4, 1.20997625E-4, 1.3744453E-4, 1.62803E-5, 2...|\n",
+ "|[1.2242574E-4, 2.8095162E-4, 6.332559E-5, 1.0209269E-4, 4.335324E-5, 3.906304E-4, 8.205706E-6, 6....|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_2.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "fc511eae",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_2.write.mode(\"overwrite\").parquet(output_file_path + \"_2\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "878ca7fb",
+ "metadata": {},
+ "source": [
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
+ "\n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "2605d134-ef75-4d94-9b16-2c6d85f29bef",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "8c8c0744-0558-4dac-bbfe-8bdde4b2af2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_server(ports):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import tensorflow as tf\n",
+ " from tensorflow.keras.applications import ResNet50\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ "\n",
+ " print(f\"SERVER: Initializing ResNet on worker {TaskContext.get().partitionId()}.\")\n",
+ "\n",
+ " # Enable GPU memory growth\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ " \n",
+ " model = ResNet50()\n",
+ " normalization_layer = tf.keras.layers.Rescaling(scale=2./255, offset=-1)\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " images = inputs[\"images\"]\n",
+ " normalized_images = normalization_layer(images)\n",
+ " return {\n",
+ " \"preds\": model.predict(normalized_images),\n",
+ " }\n",
+ "\n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"ResNet50\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"images\", dtype=np.float32, shape=(224, 224, 3)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=100,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d74f7037",
+ "metadata": {},
+ "source": [
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "13196ae8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "44a387dc",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "132fbfed",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "2309a55c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "533e2a89",
+ "metadata": {},
+ "source": [
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "825370dd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "dfc8834a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Using ports [7000, 7001, 7002]\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_name = \"ResNet50\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "ad24bc52",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 15:> (0 + 1) / 1]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2604934\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e49ebdbe",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "aa34bebb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "a5ab49bb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
+ " import numpy as np\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n",
+ "\n",
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " result_data = client.infer_batch(inputs)\n",
+ " return result_data[\"preds\"]\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fcd2328e",
+ "metadata": {},
+ "source": [
+ "#### Load DataFrame"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "bbfc9009",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.read.parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8c07365c-0a14-49b3-9bd8-cfb35f48b089",
+ "metadata": {},
+ "source": [
+ "#### Run inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "id": "9fabcaeb-5a44-42bb-8097-5dbc2d0cee3e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "classify = predict_batch_udf(partial(triton_fn, url=url, model_name=model_name),\n",
+ " input_tensor_shapes=[[224, 224, 3]],\n",
+ " return_type=ArrayType(FloatType()),\n",
+ " batch_size=50)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "e595473d-1a5d-46a6-a6ba-89d2ea903de9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 17:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 53 ms, sys: 22.7 ms, total: 75.8 ms\n",
+ "Wall time: 18.3 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "# first pass caches model/fn\n",
+ "predictions_3 = df.select(classify(struct(\"data\")).alias(\"prediction\"))\n",
+ "results = predictions_3.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "5f66d468-e0b1-4589-8606-b3848063a823",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 18:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 48.6 ms, sys: 26.3 ms, total: 74.8 ms\n",
+ "Wall time: 11.7 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions_3 = df.select(classify(\"data\").alias(\"prediction\"))\n",
+ "results = predictions_3.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "632c4c3a-fa52-4c3d-b71e-7526286e353a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 19:==================================================> (7 + 1) / 8]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 60.4 ms, sys: 18.1 ms, total: 78.5 ms\n",
+ "Wall time: 11.9 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions_3 = df.select(classify(col(\"data\")).alias(\"prediction\"))\n",
+ "results = predictions_3.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "49870e39",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 21:===========================================> (3 + 1) / 4]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "| prediction|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "|[1.2938889E-4, 2.4666305E-4, 6.765791E-5, 1.2263245E-4, 5.7486624E-5, 3.9616702E-4, 7.0566134E-6,...|\n",
+ "|[4.4501914E-5, 3.5403698E-4, 4.6702033E-5, 8.102543E-5, 3.1704556E-5, 1.9194305E-4, 7.905952E-6, ...|\n",
+ "|[1.05672516E-4, 2.2686279E-4, 3.0055395E-5, 6.523785E-5, 2.352077E-5, 3.7122983E-4, 3.3315896E-6,...|\n",
+ "|[2.0331638E-5, 2.2746396E-4, 7.828012E-5, 6.986782E-5, 4.705316E-5, 9.80732E-5, 5.561918E-6, 2.35...|\n",
+ "|[1.130241E-4, 2.3187004E-4, 5.296914E-5, 1.0871329E-4, 4.027478E-5, 3.7183522E-4, 5.5931855E-6, 3...|\n",
+ "|[9.094467E-5, 2.06384E-4, 4.514821E-5, 7.665891E-5, 3.2262324E-5, 3.3875552E-4, 3.831814E-6, 4.18...|\n",
+ "|[1.07847634E-4, 3.7848807E-4, 7.660533E-5, 1.2446754E-4, 4.7595917E-5, 3.333814E-4, 1.0669675E-5,...|\n",
+ "|[2.2261223E-5, 2.734666E-4, 3.8122747E-5, 6.2266954E-5, 1.7935155E-5, 1.7268128E-4, 6.034271E-6, ...|\n",
+ "|[1.1065645E-4, 2.900581E-4, 4.2585547E-5, 1.074203E-4, 3.052314E-5, 4.794604E-4, 6.4872897E-6, 3....|\n",
+ "|[9.673917E-5, 2.058331E-4, 7.4652424E-5, 1.1323769E-4, 4.6106186E-5, 2.8604185E-4, 5.62365E-6, 5....|\n",
+ "|[7.411196E-5, 3.291524E-4, 1.3454164E-4, 1.7738447E-4, 8.467504E-5, 2.2466244E-4, 1.3621126E-5, 1...|\n",
+ "|[8.721524E-5, 2.7338538E-4, 3.5964815E-5, 7.792533E-5, 2.3559302E-5, 3.6789547E-4, 3.5665628E-6, ...|\n",
+ "|[9.723709E-5, 2.7619812E-4, 5.7464153E-5, 1.10104906E-4, 3.8317143E-5, 3.490506E-4, 6.1553183E-6,...|\n",
+ "|[6.940235E-5, 2.5377885E-4, 5.057188E-5, 1.1485363E-4, 3.0059196E-5, 2.7862669E-4, 5.024019E-6, 5...|\n",
+ "|[4.2095784E-5, 2.4891715E-4, 1.236292E-4, 1.4306813E-4, 7.3354306E-5, 1.6047148E-4, 7.958807E-6, ...|\n",
+ "|[2.7327887E-5, 3.8553146E-4, 1.2939748E-4, 1.5762268E-4, 7.307493E-5, 8.5530424E-5, 1.2648808E-5,...|\n",
+ "|[3.036101E-5, 3.5572305E-4, 1.600718E-4, 2.1437313E-4, 8.063033E-5, 1.02061334E-4, 1.3876456E-5, ...|\n",
+ "|[3.3109587E-5, 2.8182982E-4, 1.7998899E-4, 2.0246049E-4, 1.3720036E-4, 1.01000114E-4, 3.427488E-5...|\n",
+ "|[4.549448E-5, 2.8782588E-4, 2.3703449E-4, 2.448979E-4, 1.20997625E-4, 1.3744453E-4, 1.62803E-5, 2...|\n",
+ "|[1.2242574E-4, 2.8095162E-4, 6.332559E-5, 1.0209269E-4, 4.335324E-5, 3.906304E-4, 8.205706E-6, 6....|\n",
+ "+----------------------------------------------------------------------------------------------------+\n",
+ "only showing top 20 rows\n",
+ "\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_3.show(truncate=100)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "86cd59f9",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ }
+ ],
+ "source": [
+ "predictions_3.write.mode(\"overwrite\").parquet(output_file_path + \"_3\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4dc06b7e-f750-40b5-9208-a035db11d937",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "#### Stop Triton Server on each executor"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "bbfcaa51-3b9f-43ff-a4a8-4b46766115b8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "[True]"
+ ]
+ },
+ "execution_count": 45,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
+ " \n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
+ "\n",
+ " return [False]\n",
+ "\n",
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "0d88639b-d934-4eb4-ae2f-cc13b9b10456",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "spark.stop()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "df8cc28a-34d7-479c-be7e-9a380d39e25e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "spark-dl-tf",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/1/model.py
deleted file mode 100644
index a2fa9635a..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/1/model.py
+++ /dev/null
@@ -1,162 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-import tensorflow as tf
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- print("tf: {}".format(tf.__version__))
- gpus = tf.config.list_physical_devices('GPU')
- if gpus:
- try:
- # Currently, memory growth needs to be the same across GPUs
- for gpu in gpus:
- tf.config.experimental.set_memory_growth(gpu, True)
- logical_gpus = tf.config.list_logical_devices('GPU')
- print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
- except RuntimeError as e:
- # Memory growth must be set before GPUs have been initialized
- print(e)
-
- self.model = tf.keras.models.load_model("/my_pet_classifier.keras")
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- pred_config = pb_utils.get_output_config_by_name(model_config, "pred")
-
- # Convert Triton types to numpy types
- self.pred_dtype = pb_utils.triton_string_to_numpy(pred_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
- Python model, must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- pred_dtype = self.pred_dtype
-
- responses = []
-
- def decode(input_tensor):
- return tf.convert_to_tensor([[s[0].decode('utf-8')] for s in input_tensor.as_numpy()])
-
- def identity(input_tensor):
- return tf.convert_to_tensor(input_tensor.as_numpy())
-
- input_transforms = {
- "Type": decode,
- "Age": identity,
- "Breed1": decode,
- "Gender": decode,
- "Color1": decode,
- "Color2": decode,
- "MaturitySize": decode,
- "FurLength": decode,
- "Vaccinated": decode,
- "Sterilized": decode,
- "Health": decode,
- "Fee": identity,
- "PhotoAmt": identity
- }
-
- # Every Python backend must iterate over everyone of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- inputs = {name: transform(pb_utils.get_input_tensor_by_name(request, name)) for name, transform in input_transforms.items()}
-
- pred = self.model.predict(inputs, verbose=0)
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- pred_tensor = pb_utils.Tensor("pred", np.squeeze(pred).astype(pred_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
- # output_tensors=..., TritonError("An error occured"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[pred_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/config.pbtxt
deleted file mode 100644
index 93a7cf045..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/feature_columns/config.pbtxt
+++ /dev/null
@@ -1,111 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "feature_columns"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "Type"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Age"
- data_type: TYPE_INT64
- dims: [1]
- },
- {
- name: "Breed1"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Gender"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Color1"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Color2"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "MaturitySize"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "FurLength"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Vaccinated"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Sterilized"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Health"
- data_type: TYPE_STRING
- dims: [1]
- },
- {
- name: "Fee"
- data_type: TYPE_FP32
- dims: [1]
- },
- {
- name: "PhotoAmt"
- data_type: TYPE_FP32
- dims: [1]
- }
-]
-output [
- {
- name: "pred"
- data_type: TYPE_FP32
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../tf-gpu.tar.gz"}
-}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/mnist_model/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/mnist_model/config.pbtxt
deleted file mode 100644
index cc9172f45..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/mnist_model/config.pbtxt
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-platform: "tensorflow_savedmodel"
-max_batch_size: 8192
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/resnet50/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/resnet50/config.pbtxt
deleted file mode 100644
index cc9172f45..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/resnet50/config.pbtxt
+++ /dev/null
@@ -1,17 +0,0 @@
-# Copyright (c) 2024, NVIDIA CORPORATION.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-platform: "tensorflow_savedmodel"
-max_batch_size: 8192
-
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/1/model.py b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/1/model.py
deleted file mode 100644
index 1bdef0b9c..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/1/model.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import numpy as np
-import json
-import tensorflow as tf
-
-# triton_python_backend_utils is available in every Triton Python model. You
-# need to use this module to create inference requests and responses. It also
-# contains some utility functions for extracting information from model_config
-# and converting Triton input/output types to numpy types.
-import triton_python_backend_utils as pb_utils
-
-
-class TritonPythonModel:
- """Your Python model must use the same class name. Every Python model
- that is created must have "TritonPythonModel" as the class name.
- """
-
- def initialize(self, args):
- """`initialize` is called only once when the model is being loaded.
- Implementing `initialize` function is optional. This function allows
- the model to intialize any state associated with this model.
-
- Parameters
- ----------
- args : dict
- Both keys and values are strings. The dictionary keys and values are:
- * model_config: A JSON string containing the model configuration
- * model_instance_kind: A string containing model instance kind
- * model_instance_device_id: A string containing model instance device ID
- * model_repository: Model repository path
- * model_version: Model version
- * model_name: Model name
- """
- import re
- import string
- from tensorflow.keras import layers
-
- print("tf: {}".format(tf.__version__))
-
- def custom_standardization(input_data):
- lowercase = tf.strings.lower(input_data)
- stripped_html = tf.strings.regex_replace(lowercase, " ", " ")
- return tf.strings.regex_replace(
- stripped_html, "[%s]" % re.escape(string.punctuation), ""
- )
-
- max_features = 10000
- sequence_length = 250
-
- vectorize_layer = layers.TextVectorization(
- standardize=custom_standardization,
- max_tokens=max_features,
- output_mode="int",
- output_sequence_length=sequence_length,
- )
-
- custom_objects = {"vectorize_layer": vectorize_layer,
- "custom_standardization": custom_standardization}
- with tf.keras.utils.custom_object_scope(custom_objects):
- self.model = tf.keras.models.load_model(
- "/text_model_cleaned.keras", compile=False
- )
-
- # You must parse model_config. JSON string is not parsed here
- self.model_config = model_config = json.loads(args['model_config'])
-
- # Get output configuration
- pred_config = pb_utils.get_output_config_by_name(model_config, "pred")
-
- # Convert Triton types to numpy types
- self.pred_dtype = pb_utils.triton_string_to_numpy(pred_config['data_type'])
-
- def execute(self, requests):
- """`execute` MUST be implemented in every Python model. `execute`
- function receives a list of pb_utils.InferenceRequest as the only
- argument. This function is called when an inference request is made
- for this model. Depending on the batching configuration (e.g. Dynamic
- Batching) used, `requests` may contain multiple requests. Every
- Python model, must create one pb_utils.InferenceResponse for every
- pb_utils.InferenceRequest in `requests`. If there is an error, you can
- set the error argument when creating a pb_utils.InferenceResponse
-
- Parameters
- ----------
- requests : list
- A list of pb_utils.InferenceRequest
-
- Returns
- -------
- list
- A list of pb_utils.InferenceResponse. The length of this list must
- be the same as `requests`
- """
-
- pred_dtype = self.pred_dtype
-
- responses = []
-
- # Every Python backend must iterate over everyone of the requests
- # and create a pb_utils.InferenceResponse for each of them.
- for request in requests:
- # Get input numpy
- sentence_input = pb_utils.get_input_tensor_by_name(request, "sentence")
- sentences = sentence_input.as_numpy()
- sentences = np.squeeze(sentences).tolist()
- sentences = [s.decode('utf-8') for s in sentences]
- sentences = tf.convert_to_tensor(sentences)
-
- pred = self.model.predict(sentences, verbose=0)
-
- # Create output tensors. You need pb_utils.Tensor
- # objects to create pb_utils.InferenceResponse.
- pred_tensor = pb_utils.Tensor("pred", pred.astype(pred_dtype))
-
- # Create InferenceResponse. You can set an error here in case
- # there was a problem with handling this inference request.
- # Below is an example of how you can set errors in inference
- # response:
- #
- # pb_utils.InferenceResponse(
- # output_tensors=..., TritonError("An error occured"))
- inference_response = pb_utils.InferenceResponse(output_tensors=[pred_tensor])
- responses.append(inference_response)
-
- # You should return a list of pb_utils.InferenceResponse. Length
- # of this list must match the length of `requests` list.
- return responses
-
- def finalize(self):
- """`finalize` is called only once when the model is being unloaded.
- Implementing `finalize` function is OPTIONAL. This function allows
- the model to perform any necessary clean ups before exit.
- """
- print('Cleaning up...')
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/config.pbtxt b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/config.pbtxt
deleted file mode 100644
index 44b21bf39..000000000
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/models_config/text_classification/config.pbtxt
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above copyright
-# notice, this list of conditions and the following disclaimer in the
-# documentation and/or other materials provided with the distribution.
-# * Neither the name of NVIDIA CORPORATION nor the names of its
-# contributors may be used to endorse or promote products derived
-# from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-name: "text_classification"
-backend: "python"
-max_batch_size: 8192
-
-input [
- {
- name: "sentence"
- data_type: TYPE_STRING
- dims: [1]
- }
-]
-output [
- {
- name: "pred"
- data_type: TYPE_FP32
- dims: [1]
- }
-]
-
-instance_group [{ kind: KIND_GPU }]
-
-parameters: {
- key: "EXECUTION_ENV_PATH",
- value: {string_value: "$$TRITON_MODEL_DIRECTORY/../tf-gpu.tar.gz"}
-}
diff --git a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb
index 63499611c..09c203192 100644
--- a/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb
+++ b/examples/ML+DL-Examples/Spark-DL/dl_inference/tensorflow/text_classification_tf.ipynb
@@ -5,9 +5,12 @@
"id": "2cd2accf-5877-4136-a243-7a33a13ce2b4",
"metadata": {},
"source": [
+ "\n",
+ "\n",
"# Pyspark TensorFlow Inference\n",
"\n",
- "## Text classification\n",
+ "### Text Classification\n",
+ "In this notebook, we demonstrate training a model to perform sentiment analysis, and using the trained model for distributed inference. \n",
"Based on: https://www.tensorflow.org/tutorials/keras/text_classification"
]
},
@@ -16,9 +19,7 @@
"id": "bc72d0ed",
"metadata": {},
"source": [
- "### Using TensorFlow\n",
- "Note that cuFFT/cuDNN/cuBLAS registration errors are expected with `tf=2.17.0` and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) \n",
- "This notebook does not demonstrate inference with TensorRT, as [TF-TRT](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html#tensorrt-10) does not yet support `tf=2.17.0`. See the `pytorch` notebooks for TensorRT demos."
+ "Note that cuFFT/cuDNN/cuBLAS registration errors are expected (as of `tf=2.17.0`) and will not affect behavior, as noted in [this issue.](https://github.com/tensorflow/tensorflow/issues/62075) "
]
},
{
@@ -31,13 +32,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-24 16:15:43.020721: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
- "2024-10-24 16:15:43.028070: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
- "2024-10-24 16:15:43.035674: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
- "2024-10-24 16:15:43.037910: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
- "2024-10-24 16:15:43.044256: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "2025-01-07 17:55:03.625173: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n",
+ "2025-01-07 17:55:03.632499: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2025-01-07 17:55:03.640392: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2025-01-07 17:55:03.642797: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2025-01-07 17:55:03.648973: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
- "2024-10-24 16:15:43.368732: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ "2025-01-07 17:55:04.012978: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
@@ -46,8 +47,8 @@
"import re\n",
"import shutil\n",
"import string\n",
- "\n",
"import matplotlib.pyplot as plt\n",
+ "\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers, losses"
]
@@ -67,16 +68,8 @@
}
],
"source": [
- "print(tf.__version__)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "57b1d71f",
- "metadata": {},
- "outputs": [],
- "source": [
+ "print(tf.__version__)\n",
+ "\n",
"# Enable GPU memory growth\n",
"gpus = tf.config.experimental.list_physical_devices('GPU')\n",
"if gpus:\n",
@@ -88,131 +81,91 @@
]
},
{
- "cell_type": "code",
- "execution_count": 10,
- "id": "d229c1b6-3967-46b5-9ea8-68f4b42dd211",
+ "cell_type": "markdown",
+ "id": "b64bb471",
"metadata": {},
- "outputs": [],
"source": [
- "import pathlib\n",
- "url = \"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
- "\n",
- "dataset = tf.keras.utils.get_file(\n",
- " fname=\"aclImdb\", origin=url, untar=True,\n",
- ")\n",
- "\n",
- "dataset_dir = pathlib.Path(dataset)"
+ "### Download and explore the dataset"
]
},
{
"cell_type": "code",
- "execution_count": 11,
- "id": "bfa5177f",
+ "execution_count": 3,
+ "id": "d229c1b6-3967-46b5-9ea8-68f4b42dd211",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "/home/rishic/.keras/datasets/aclImdb\n",
- "/home/rishic/.keras/datasets/aclImdb\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "print(dataset_dir)\n",
- "# aclImdb might be created as a directory containing a single directory aclImdb. Check if this is the case:\n",
- "if os.path.exists(dataset_dir / \"aclImdb\"):\n",
- " dataset_dir = dataset_dir / \"aclImdb\"\n",
- "print(dataset_dir)"
+ "from datasets import load_dataset\n",
+ "dataset = load_dataset(\"imdb\")"
]
},
{
"cell_type": "code",
- "execution_count": 6,
- "id": "1f8038ae-8bc1-46bf-ae4c-6da08886c473",
+ "execution_count": 4,
+ "id": "88f9a92e",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['README', 'imdb.vocab', 'test', 'train', 'imdbEr.txt']"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "os.listdir(dataset_dir)"
+ "# Create directories for our data\n",
+ "base_dir = \"spark-dl-datasets/imdb\"\n",
+ "if os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False):\n",
+ " # For databricks, use the driver disk rather than Workspace (much faster)\n",
+ " base_dir = \"/local_disk0/\" + base_dir\n",
+ "\n",
+ "train_dir = base_dir + \"/train\"\n",
+ "test_dir = base_dir + \"/test\""
]
},
{
"cell_type": "code",
- "execution_count": 13,
- "id": "12faaa3f-3441-4361-b9eb-4317e8c2c2f7",
+ "execution_count": 5,
+ "id": "3f984d5a",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "['pos',\n",
- " 'labeledBow.feat',\n",
- " 'urls_pos.txt',\n",
- " 'neg',\n",
- " 'urls_unsup.txt',\n",
- " 'unsupBow.feat',\n",
- " 'urls_neg.txt',\n",
- " 'unsup']"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "train_dir = os.path.join(dataset_dir, \"train\")\n",
- "test_dir = os.path.join(dataset_dir, \"test\")\n",
- "os.listdir(train_dir)"
+ "# Create directories for positive (1) and negative (0) reviews\n",
+ "for split in [\"train\", \"test\"]:\n",
+ " split_dir = os.path.join(base_dir, split)\n",
+ " pos_dir = split_dir + \"/pos\"\n",
+ " neg_dir = split_dir + \"/neg\"\n",
+ "\n",
+ " os.makedirs(pos_dir, exist_ok=True)\n",
+ " os.makedirs(neg_dir, exist_ok=True)"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "152cc0cc-65d0-4e17-9ee8-222390df45b5",
+ "execution_count": 6,
+ "id": "6cd2328a",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Rachel Griffiths writes and directs this award winning short film. A heartwarming story about coping with grief and cherishing the memory of those we've loved and lost. Although, only 15 minutes long, Griffiths manages to capture so much emotion and truth onto film in the short space of time. Bud Tingwell gives a touching performance as Will, a widower struggling to cope with his wife's death. Will is confronted by the harsh reality of loneliness and helplessness as he proceeds to take care of Ruth's pet cow, Tulip. The film displays the grief and responsibility one feels for those they have loved and lost. Good cinematography, great direction, and superbly acted. It will bring tears to all those who have lost a loved one, and survived.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "sample_file = os.path.join(train_dir, \"pos/1181_9.txt\")\n",
- "with open(sample_file) as f:\n",
- " print(f.read())"
+ "def write_reviews_to_files(dataset_split, split_name):\n",
+ " for idx, example in enumerate(dataset_split):\n",
+ " label_dir = \"pos\" if example[\"label\"] == 1 else \"neg\"\n",
+ " dir_path = os.path.join(base_dir, split_name, label_dir)\n",
+ "\n",
+ " file_path = dir_path + f\"/review_{idx}.txt\"\n",
+ " with open(file_path, \"w\", encoding=\"utf-8\") as f:\n",
+ " f.write(example[\"text\"])\n",
+ "\n",
+ "# Write train and test sets\n",
+ "write_reviews_to_files(dataset[\"train\"], \"train\")\n",
+ "write_reviews_to_files(dataset[\"test\"], \"test\")"
]
},
{
- "cell_type": "code",
- "execution_count": 15,
- "id": "b2277f58-78c8-4a12-bc98-5103e7c81a35",
+ "cell_type": "markdown",
+ "id": "b02fde64",
"metadata": {},
- "outputs": [],
"source": [
- "remove_dir = os.path.join(train_dir, \"unsup\")\n",
- "shutil.rmtree(remove_dir)"
+ "There are 25,000 examples in the training folder, of which we will use 80% (or 20,000) for training, and 5,000 for validation."
]
},
{
"cell_type": "code",
- "execution_count": 17,
- "id": "ed83de92-ebb3-4170-b2bf-25265c6a6942",
+ "execution_count": 7,
+ "id": "5c357f22",
"metadata": {},
"outputs": [
{
@@ -227,7 +180,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-24 02:18:45.343343: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 46446 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
+ "2025-01-07 17:55:15.035387: I tensorflow/core/common_runtime/gpu/gpu_device.cc:2021] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 45468 MB memory: -> device: 0, name: NVIDIA RTX A6000, pci bus id: 0000:01:00.0, compute capability: 8.6\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Found 25000 files belonging to 2 classes.\n",
+ "Using 5000 files for validation.\n",
+ "Found 25000 files belonging to 2 classes.\n"
]
}
],
@@ -236,29 +198,50 @@
"seed = 42\n",
"\n",
"raw_train_ds = tf.keras.utils.text_dataset_from_directory(\n",
- " train_dir,\n",
+ " str(train_dir),\n",
" batch_size=batch_size,\n",
" validation_split=0.2,\n",
" subset=\"training\",\n",
" seed=seed,\n",
+ ")\n",
+ "\n",
+ "raw_val_ds = tf.keras.utils.text_dataset_from_directory(\n",
+ " str(train_dir),\n",
+ " batch_size=batch_size,\n",
+ " validation_split=0.2,\n",
+ " subset=\"validation\",\n",
+ " seed=seed,\n",
+ ")\n",
+ "\n",
+ "raw_test_ds = tf.keras.utils.text_dataset_from_directory(\n",
+ " str(test_dir),\n",
+ " batch_size=batch_size\n",
")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "02994994",
+ "metadata": {},
+ "source": [
+ "We can take a look at a sample of the dataset (note that OUT_OF_RANGE errors are safe to ignore: https://github.com/tensorflow/tensorflow/issues/62963):"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 11,
- "id": "57c30568-daa8-4b2b-b30a-577c984a8af5",
+ "execution_count": 8,
+ "id": "1d528a95",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Review b'\"Pandemonium\" is a horror movie spoof that comes off more stupid than funny. Believe me when I tell you, I love comedies. Especially comedy spoofs. \"Airplane\", \"The Naked Gun\" trilogy, \"Blazing Saddles\", \"High Anxiety\", and \"Spaceballs\" are some of my favorite comedies that spoof a particular genre. \"Pandemonium\" is not up there with those films. Most of the scenes in this movie had me sitting there in stunned silence because the movie wasn\\'t all that funny. There are a few laughs in the film, but when you watch a comedy, you expect to laugh a lot more than a few times and that\\'s all this film has going for it. Geez, \"Scream\" had more laughs than this film and that was more of a horror film. How bizarre is that?
*1/2 (out of four)'\n",
+ "Review b'I was really, really disappointed with this movie. it started really well, and built up some great atmosphere and suspense, but when it finally got round to revealing the \"monster\"...it turned out to be just some psycho with skin problems......again. Whoop-de-do. Yet another nutjob movie...like we don\\'t already have enough of them.
To be fair, the \"creep\" is genuinely unsettling to look at, and the way he moves and the strange sounds he makes are pretty creepy, but I\\'m sick of renting film like this only to discover that the monster is human, albeit a twisted, demented, freakish one. When I saw all the tell-tale rats early on I was hoping for some kind of freaky rat-monster hybrid thing...it was such a let down when the Creep was revealed.
On top of this, some of the stuff in this movie makes no sense. (Spoiler)
Why the hell does the Creep kill the security Guard? Whats the point, apart from sticking a great honking sign up that says \"HI I\\'m A PSYCHO AND I LIVE DOWN HERE!\"? Its stupid, and only seems to happen to prevent Franka Potente\\'s character from getting help.
what the hells he been eating down there? I got the impression he was effectively walled in, and only the unexpected opening into that tunnel section let him loose...so has he been munching rats all that time, and if so why do they hang around him so much? Why is he so damn hard to kill? He\\'s thin, malnourished and not exactly at peak performance...but seems to keep going despite injuries that are equivalent to those that .cripple the non-psycho characters in the film.
The DVD commentary says we are intended to empathise with Creep, but I just find him loathsome. Its an effective enough movie, but it wasted so many opportunities that it makes me sick.'\n",
"Label 0\n",
- "Review b\"David Mamet is a very interesting and a very un-equal director. His first movie 'House of Games' was the one I liked best, and it set a series of films with characters whose perspective of life changes as they get into complicated situations, and so does the perspective of the viewer.
So is 'Homicide' which from the title tries to set the mind of the viewer to the usual crime drama. The principal characters are two cops, one Jewish and one Irish who deal with a racially charged area. The murder of an old Jewish shop owner who proves to be an ancient veteran of the Israeli Independence war triggers the Jewish identity in the mind and heart of the Jewish detective.
This is were the flaws of the film are the more obvious. The process of awakening is theatrical and hard to believe, the group of Jewish militants is operatic, and the way the detective eventually walks to the final violent confrontation is pathetic. The end of the film itself is Mamet-like smart, but disappoints from a human emotional perspective.
Joe Mantegna and William Macy give strong performances, but the flaws of the story are too evident to be easily compensated.\"\n",
+ "Review b\"This has the absolute worst performance from Robert Duval who sounds just like William Buckley throughout the entire film. His hammy melodramatic acting takes away from any dramatic interest. I'm not sure if this was deliberate scene stealing or inadvertent but it's the only thing I can recall from a truly forgettable film. This picture should be shown in every amateur acting class of an example of what not to do. Thank God, Duvall went on to bigger and better things and stopped trying to effect a cultured accent. He is a good character actor but that's about it. Klaus is so much better. His performance is muted and noteworthy.\"\n",
"Label 0\n",
- "Review b'Great documentary about the lives of NY firefighters during the worst terrorist attack of all time.. That reason alone is why this should be a must see collectors item.. What shocked me was not only the attacks, but the\"High Fat Diet\" and physical appearance of some of these firefighters. I think a lot of Doctors would agree with me that,in the physical shape they were in, some of these firefighters would NOT of made it to the 79th floor carrying over 60 lbs of gear. Having said that i now have a greater respect for firefighters and i realize becoming a firefighter is a life altering job. The French have a history of making great documentary\\'s and that is what this is, a Great Documentary.....'\n",
+ "Review b'A long time ago, in a galaxy far, far away.....There was a boy who was only two years old when the original \"Star Wars\" film was released. He doesn\\'t remember first seeing the movie, but he also doesn\\'t remember life before it. He does remember the first \"Star Wars\" themed gift he got...a shoebox full of action figures from the original set. He was too young to fully appreciate how special that gift would be. But years later, he would get what to this day goes down as one of the best gifts he\\'s ever received: another box full of action figures, ten of the final twelve he needed to complete his collection. It\\'s now legendary in this boy\\'s family how the last action figure he needed, Anakin Skywalker, stopped being produced and carried in stores, and how this boy went for about ten years (until he got into college) trying to track one down and finally bought it from someone on his dorm floor for a bag of beer nuggets (don\\'t ask...it\\'s a Northern Illinois University thing).
I can\\'t review \"Star Wars\" as a movie. It represents absolutely everything good, fun and magical about my childhood. There\\'s no separating it in my mind from Christmases, birthdays, summers and winters growing up. In the winter, my friends and I would build snow forts and pretend we were on Hoth (I was always Han Solo). My friends\\' dad built them a kick-ass tree house, and that served as the Ewok village. They also had a huge pine tree whose bottom branches were high enough to create a sort of cave underneath it, and this made a great spot to pretend we were in Yoda\\'s home. I am unabashedly dorky when it comes to \"Star Wars\" and I think people either just understand that or they don\\'t. I don\\'t get the appeal of \"Lord of the Rings\" or \"Star Trek\" but I understand the rabid flocks of fans that follow them because I am a rabid fan of George Lucas\\'s films.
I feel no need to defend my opinion of these movies as some of the greatest of all time. Every time I put them in the DVD player, I feel like I\\'m eight years old again, when life was simple and the biggest problem I had was figuring out how I was going to track down a figure of Anakin Skywalker.
Grade (for the entire trilogy): A+'\n",
"Label 1\n"
]
},
@@ -266,7 +249,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:44:08.132892: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-07 17:55:21.572943: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
}
],
@@ -278,73 +261,47 @@
]
},
{
- "cell_type": "code",
- "execution_count": 12,
- "id": "1e863eb6-4bd7-4da0-b10d-d951b5ee52bd",
+ "cell_type": "markdown",
+ "id": "4bca98b1",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Label 0 corresponds to neg\n",
- "Label 1 corresponds to pos\n"
- ]
- }
- ],
"source": [
- "print(\"Label 0 corresponds to\", raw_train_ds.class_names[0])\n",
- "print(\"Label 1 corresponds to\", raw_train_ds.class_names[1])"
+ "Notice the reviews contain raw text (with punctuation and occasional HTML tags like \\ \\). We will show how to handle these in the following section."
]
},
{
"cell_type": "code",
- "execution_count": 13,
- "id": "1593e2e5-df51-4fbf-b4be-c786e740ddab",
+ "execution_count": 9,
+ "id": "f8921ed2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Found 25000 files belonging to 2 classes.\n",
- "Using 5000 files for validation.\n"
+ "Label 0 corresponds to neg\n",
+ "Label 1 corresponds to pos\n"
]
}
],
"source": [
- "raw_val_ds = tf.keras.utils.text_dataset_from_directory(\n",
- " train_dir,\n",
- " batch_size=batch_size,\n",
- " validation_split=0.2,\n",
- " subset=\"validation\",\n",
- " seed=seed,\n",
- ")"
+ "print(\"Label 0 corresponds to\", raw_train_ds.class_names[0])\n",
+ "print(\"Label 1 corresponds to\", raw_train_ds.class_names[1])"
]
},
{
- "cell_type": "code",
- "execution_count": 14,
- "id": "944fd61d-3926-4296-889a-b2a375a1b039",
+ "cell_type": "markdown",
+ "id": "f6cf0e47",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 25000 files belonging to 2 classes.\n"
- ]
- }
- ],
"source": [
- "raw_test_ds = tf.keras.utils.text_dataset_from_directory(\n",
- " test_dir, batch_size=batch_size\n",
- ")"
+ "### Prepare the dataset for training\n",
+ "\n",
+ "Next, we will standardize, tokenize, and vectorize the data using the tf.keras.layers.TextVectorization layer. \n",
+ "We will write a custom standardization function to remove the HTML."
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 10,
"id": "cb141709-fcc1-4cee-bc98-9c89aaba8648",
"metadata": {},
"outputs": [],
@@ -357,9 +314,17 @@
" )"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "b35e36a2",
+ "metadata": {},
+ "source": [
+ "Next, we will create a TextVectorization layer to standardize, tokenize, and vectorize our data."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 11,
"id": "d4e80ea9-536a-4ebc-8b35-1eca73dbba7d",
"metadata": {},
"outputs": [],
@@ -375,9 +340,17 @@
")"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "879fbc3f",
+ "metadata": {},
+ "source": [
+ "Next, we will call adapt to fit the state of the preprocessing layer to the dataset. This will cause the model to build an index of strings to integers."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 12,
"id": "ad1e5d81-7dae-4b08-b520-ca45501b9510",
"metadata": {},
"outputs": [
@@ -385,7 +358,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2024-10-03 17:44:10.225130: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
+ "2025-01-07 17:55:35.387277: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n"
]
}
],
@@ -395,9 +368,17 @@
"vectorize_layer.adapt(train_text)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "ad1e5d81-7dae-4b08-b520-ca45501b9510",
+ "metadata": {},
+ "source": [
+ "Let's create a function to see the result of using this layer to preprocess some data."
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 13,
"id": "80f243f5-edd3-4e1c-bddc-abc1cc6673ef",
"metadata": {},
"outputs": [],
@@ -409,7 +390,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 14,
"id": "8f37e95c-515c-4edb-a1ee-fc47be5df4b9",
"metadata": {},
"outputs": [
@@ -417,32 +398,32 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "Review tf.Tensor(b'Silent Night, Deadly Night 5 is the very last of the series, and like part 4, it\\'s unrelated to the first three except by title and the fact that it\\'s a Christmas-themed horror flick.
Except to the oblivious, there\\'s some obvious things going on here...Mickey Rooney plays a toymaker named Joe Petto and his creepy son\\'s name is Pino. Ring a bell, anyone? Now, a little boy named Derek heard a knock at the door one evening, and opened it to find a present on the doorstep for him. Even though it said \"don\\'t open till Christmas\", he begins to open it anyway but is stopped by his dad, who scolds him and sends him to bed, and opens the gift himself. Inside is a little red ball that sprouts Santa arms and a head, and proceeds to kill dad. Oops, maybe he should have left well-enough alone. Of course Derek is then traumatized by the incident since he watched it from the stairs, but he doesn\\'t grow up to be some killer Santa, he just stops talking.
There\\'s a mysterious stranger lurking around, who seems very interested in the toys that Joe Petto makes. We even see him buying a bunch when Derek\\'s mom takes him to the store to find a gift for him to bring him out of his trauma. And what exactly is this guy doing? Well, we\\'re not sure but he does seem to be taking these toys apart to see what makes them tick. He does keep his landlord from evicting him by promising him to pay him in cash the next day and presents him with a \"Larry the Larvae\" toy for his kid, but of course \"Larry\" is not a good toy and gets out of the box in the car and of course, well, things aren\\'t pretty.
Anyway, eventually what\\'s going on with Joe Petto and Pino is of course revealed, and as with the old story, Pino is not a \"real boy\". Pino is probably even more agitated and naughty because he suffers from \"Kenitalia\" (a smooth plastic crotch) so that could account for his evil ways. And the identity of the lurking stranger is revealed too, and there\\'s even kind of a happy ending of sorts. Whee.
A step up from part 4, but not much of one. Again, Brian Yuzna is involved, and Screaming Mad George, so some decent special effects, but not enough to make this great. A few leftovers from part 4 are hanging around too, like Clint Howard and Neith Hunter, but that doesn\\'t really make any difference. Anyway, I now have seeing the whole series out of my system. Now if I could get some of it out of my brain. 4 out of 5.', shape=(), dtype=string)\n",
+ "Review tf.Tensor(b\"To describe this film as garbage is unfair. At least rooting through garbage can be an absorbing hobby. This flick was neither absorbing nor entertaining.
Kevin Bacon can act superbly given the chance, so no doubt had an IRS bill to settle when he agreed to this dire screenplay. The mad scientist story of 'Hollow Man' has been told before, been told better, and been told without resorting to so many ludicrously expensive special effects.
Most of those special effects seem to be built around the transparent anatomical dolls of men, women and dogs you could buy in the early seventies. In the UK they were marketed as 'The Transparent Man (/Woman/Dog)' which is maybe where they got the title for this film.
Clever special effects, dire script, non-existent plot.
The show used to be awesome, however, Bam's fame and wealth has led him to believe, that we now enjoy him acting childish and idiotic, more than actual cool stuff, that used to be in ex. CKY2K.
The acts are so repetitive, there's like nothing new, except annoying stupidity and rehearsed comments... The only things we see is Bam Margera, so busy showing us how much he doesn't care, how much money he got or whatsoever.
I really got nothing much left to say except, give us back CKY2K, cause Bam suck..
I enjoy watching Steve-o, Knoxville etc. a thousand times more.\")]"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.take(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "9d9db063",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 22:02:36 WARN TaskSetManager: Stage 3 contains a task of very large size (4021 KiB). The maximum recommended task size is 1000 KiB.\n"
+ ]
+ }
+ ],
+ "source": [
+ "data_path = \"spark-dl-datasets/imdb_test\"\n",
+ "if on_databricks:\n",
+ " data_path = \"dbfs:/FileStore/\" + data_path\n",
+ "\n",
+ "df.write.mode(\"overwrite\").parquet(data_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2f78a16a",
+ "metadata": {},
+ "source": [
+ "#### Load and Preprocess PySpark DataFrame\n",
+ "\n",
+ "Define our preprocess function. We'll take the first sentence of each sample as our input for sentiment analysis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "1c081557",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@pandas_udf(\"string\")\n",
+ "def preprocess(text: pd.Series) -> pd.Series:\n",
+ " return pd.Series([s.split(\".\")[0] for s in text])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "60af570a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "a690f6df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input_df = df.select(preprocess(col(\"text\")).alias(\"lines\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "01166d97",
+ "metadata": {},
+ "source": [
+ "## Inference using Spark DL API\n",
+ "\n",
+ "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/3.4.0/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
+ "\n",
+ "- predict_batch_fn uses Tensorflow APIs to load the model and return a predict function which operates on numpy arrays \n",
+ "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
"id": "7b7a8395-e2ae-4c3c-bf57-763dfde600ad",
"metadata": {},
"outputs": [],
"source": [
- "text_model_path = \"{}/text_model.keras\".format(os.getcwd())"
+ "text_model_path = \"{}/models/text_model.keras\".format(os.getcwd())\n",
+ "\n",
+ "# For cloud environments, copy the model to the distributed file system.\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+ " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/text_model.keras\"\n",
+ " shutil.copy(text_model_path, dbfs_model_path)\n",
+ " text_model_path = dbfs_model_path\n",
+ "elif on_dataproc:\n",
+ " # GCS is mounted at /mnt/gcs by the init script\n",
+ " models_dir = \"/mnt/gcs/spark-dl/models\"\n",
+ " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+ " gcs_model_path = models_dir + \"/text_model.keras\"\n",
+ " shutil.copy(text_model_path, gcs_model_path)\n",
+ " text_model_path = gcs_model_path"
]
},
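+ {
+ "cell_type": "markdown",
+ "id": "3fb2a1c0",
+ "metadata": {},
+ "source": [
+ "For reference, here is a minimal sketch of a predict_batch_fn for this model, assuming the text_model_path defined above. The definition actually used by classify below may differ (e.g. it may register custom objects for the vectorization layer, as the Triton server function does later)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5c8e2d71",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative sketch only: load the saved Keras model on the executor and return\n",
+ "# a predict function over numpy arrays, as predict_batch_udf expects.\n",
+ "def predict_batch_fn():\n",
+ " import tensorflow as tf\n",
+ " model = tf.keras.models.load_model(text_model_path)\n",
+ " def predict(inputs):\n",
+ " return model.predict(inputs)\n",
+ " return predict"
+ ]
+ },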
{
"cell_type": "code",
- "execution_count": 41,
+ "execution_count": 46,
"id": "8c0524cf-3a75-4fb8-8025-f0654acce13e",
"metadata": {},
"outputs": [],
@@ -1206,7 +1456,7 @@
},
{
"cell_type": "code",
- "execution_count": 42,
+ "execution_count": 47,
"id": "0d603644-d938-4c87-aa8a-2512251638d5",
"metadata": {},
"outputs": [],
@@ -1218,7 +1468,7 @@
},
{
"cell_type": "code",
- "execution_count": 43,
+ "execution_count": 48,
"id": "0b480622-8dc1-4879-933e-c43112768630",
"metadata": {},
"outputs": [
@@ -1226,97 +1476,76 @@
"name": "stderr",
"output_type": "stream",
"text": [
- " \r"
+ "[Stage 9:> (0 + 8) / 8]\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 22.3 ms, sys: 6.39 ms, total: 28.7 ms\n",
- "Wall time: 5.95 s\n"
+ "CPU times: user 5.29 ms, sys: 4.43 ms, total: 9.73 ms\n",
+ "Wall time: 4.3 s\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
]
}
],
"source": [
"%%time\n",
- "predictions = df.withColumn(\"preds\", classify(struct(\"lines\")))\n",
+ "predictions = input_df.withColumn(\"preds\", classify(struct(\"lines\")))\n",
"results = predictions.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 44,
+ "execution_count": 49,
"id": "31b0a262-387e-4a5e-a60e-b9b8ee456199",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 5:==============================================> (8 + 2) / 10]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 98.3 ms, sys: 8.08 ms, total: 106 ms\n",
- "Wall time: 1.24 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 4.94 ms, sys: 0 ns, total: 4.94 ms\n",
+ "Wall time: 150 ms\n"
]
}
],
"source": [
"%%time\n",
- "predictions = df.withColumn(\"preds\", classify(\"lines\"))\n",
+ "predictions = input_df.withColumn(\"preds\", classify(\"lines\"))\n",
"results = predictions.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 45,
+ "execution_count": 50,
"id": "7ef9e431-59f5-4b29-9f79-ae16a9cfb0b9",
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "[Stage 8:==============================================> (8 + 2) / 10]\r"
- ]
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "CPU times: user 16.1 ms, sys: 4.41 ms, total: 20.6 ms\n",
- "Wall time: 1.18 s\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- " \r"
+ "CPU times: user 2.38 ms, sys: 2.54 ms, total: 4.92 ms\n",
+ "Wall time: 206 ms\n"
]
}
],
"source": [
"%%time\n",
- "predictions = df.withColumn(\"preds\", classify(col(\"lines\")))\n",
+ "predictions = input_df.withColumn(\"preds\", classify(col(\"lines\")))\n",
"results = predictions.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 46,
+ "execution_count": 51,
"id": "9a325ee2-3268-414a-bb75-a5fcf794f512",
"metadata": {
"scrolled": true
@@ -1329,26 +1558,26 @@
"+--------------------------------------------------------------------------------+----------+\n",
"| lines| preds|\n",
"+--------------------------------------------------------------------------------+----------+\n",
- "|i do not understand at all why this movie received such good grades from crit...| 0.5006337|\n",
- "| I am a big fan of The ABC Movies of the Week genre|0.57577586|\n",
- "| Strangeland is a terrible horror/technological thriller| 0.5441176|\n",
- "|Sex,Drugs,Rock & Roll is without a doubt the worst product of Western Civiliz...|0.53261155|\n",
- "| Not to be mistaken as the highly touted Samuel L| 0.5785005|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second ...|0.51597977|\n",
- "| No idea how this is rated as high as it is (5|0.55052567|\n",
- "|When I saw this in the cinema, I remember wincing at the bad acting about a m...|0.52347463|\n",
- "| I was shocked at how bad it was and unable to turn away from the disaster| 0.5262873|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial s...| 0.5175539|\n",
- "|Greetings All,
Isn't it amazing the power that films have on you a...|0.64540815|\n",
- "| I'm sorry but this guy is not funny| 0.5385401|\n",
- "|This movie is so dull I spent half of it on IMDb while it was open in another...| 0.5182078|\n",
- "| OK, lets start with the best| 0.5611213|\n",
- "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...| 0.5557351|\n",
- "|The 3rd in the series finds Paul Kersey (Bronson) turning vigilante to get re...|0.56089103|\n",
- "| I'm not sure I've ever seen a film as bad as this| 0.54292|\n",
- "| Steven Seagal has made a really dull, bad and boring movie| 0.5089991|\n",
- "| You have to acknowledge Cimino's contribution to cinema| 0.5760211|\n",
- "|***SPOILER ALERT*** Disjointed and confusing arson drama that has to do with ...| 0.5469447|\n",
+ "|The only reason I'm even giving this movie a 4 is because it was made in to a...|0.52321863|\n",
+ "|Awkward disaster mishmash has a team of scavengers coming across the overturn...|0.55067354|\n",
+ "|Here is a fantastic concept for a film - a series of meteors crash into a sma...| 0.6197893|\n",
+ "| I walked out of the cinema having suffered this film after 30 mins| 0.5503541|\n",
+ "|A wildly uneven film where the major problem is the uneasy mix of comedy and ...| 0.5540192|\n",
+ "|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...| 0.5467422|\n",
+ "| A good cast| 0.5688838|\n",
+ "|Yet again, I appear to be the only person on planet Earth who is capable of c...|0.55650306|\n",
+ "|As a serious horror fan, I get that certain marketing ploys are used to sell ...| 0.5629433|\n",
+ "|Upon writing this review I have difficulty trying to think of what to write a...| 0.5383269|\n",
+ "| Simply awful| 0.5275883|\n",
+ "|I am a fan of Ed Harris' work and I really had high expectations about this film|0.55910736|\n",
+ "| Well|0.56994545|\n",
+ "| This is a new approach to comedy| 0.5674365|\n",
+ "| It's been mentioned by others the inane dialogue in this series and I agree|0.55741817|\n",
+ "|One of the most boring movies I've ever had to sit through, it's completely f...| 0.5303776|\n",
+ "|This movie was playing on Lifetime Movie Network last month and I decided to ...| 0.5663204|\n",
+ "| 1983's \"Frightmare\" is an odd little film| 0.560836|\n",
+ "| 'Felony' is a B-movie| 0.5602156|\n",
+ "| This movie defines the word \"confused\"| 0.5535761|\n",
"+--------------------------------------------------------------------------------+----------+\n",
"only showing top 20 rows\n",
"\n"
@@ -1361,79 +1590,35 @@
},
{
"cell_type": "markdown",
- "id": "579b53bf-5a8a-4f85-a5b5-fb82a4be7f06",
+ "id": "ad9b07e6",
"metadata": {},
"source": [
- "### Using Triton Inference Server\n",
+ "## Using Triton Inference Server\n",
+ "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for DL. \n",
+ "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server. \n",
+ "\n",
+ "The process looks like this:\n",
+ "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
+ "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
+ "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
+ "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
"\n",
- "Note: you can restart the kernel and run from this point to simulate running in a different node or environment."
+ ""
]
},
{
"cell_type": "markdown",
- "id": "8598edb1-acb7-4704-8f0d-20b0f431a323",
+ "id": "889a1623",
"metadata": {},
"source": [
- "This notebook uses the [Python backend with a custom execution environment](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments) for Triton 24.08, using a conda-pack environment created as follows:\n",
- "```\n",
- "conda create -n tf-gpu -c conda-forge python=3.10.0\n",
- "conda activate tf-gpu\n",
- "\n",
- "export PYTHONNOUSERSITE=True\n",
- "pip install numpy==1.26.4 tensorflow[and-cuda] conda-pack\n",
- "\n",
- "conda pack # tf-gpu.tar.gz\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 47,
- "id": "772e337e-1098-4c7b-ba81-8cb221a518e2",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "import numpy as np\n",
- "import os\n",
- "from pyspark.ml.functions import predict_batch_udf\n",
- "from pyspark.sql.functions import col, struct\n",
- "from pyspark.sql.types import ArrayType, FloatType"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 48,
- "id": "69d0c93a-bb0b-46c5-9d28-7b08a2e70964",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
- "outputs": [],
- "source": [
- "%%bash\n",
- "# copy custom model to expected layout for Triton\n",
- "rm -rf models\n",
- "mkdir -p models\n",
- "cp -r models_config/text_classification models\n",
- "\n",
- "# add custom execution environment\n",
- "cp tf-gpu.tar.gz models"
+ "First we'll cleanup the vocabulary layer of the model to remove non-ASCII characters. This ensures the inputs can be properly serialized and sent to Triton."
]
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 52,
"id": "f4f14c8f",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
"import unicodedata\n",
@@ -1454,299 +1639,442 @@
"normalized_vocab = normalize_vocabulary(vocab)\n",
"\n",
"# Reassign the cleaned vocabulary to the TextVectorization layer\n",
- "vectorize_layer.set_vocabulary(normalized_vocab)\n",
- "\n",
- "# Save the model with the cleaned vocabulary\n",
- "export_model.save('text_model_cleaned.keras')"
+ "vectorize_layer.set_vocabulary(normalized_vocab)"
]
},
{
- "cell_type": "markdown",
- "id": "0d8c9ab3-57c4-45bb-9bcf-6433337ef9b5",
+ "cell_type": "code",
+ "execution_count": 53,
+ "id": "9614a192",
"metadata": {},
+ "outputs": [],
"source": [
- "#### Start Triton Server on each executor"
+ "# Save the model with the cleaned vocabulary\n",
+ "triton_model_path = '{}/models/text_model_cleaned.keras'.format(os.getcwd())\n",
+ "export_model.save(triton_model_path)\n",
+ "\n",
+ "# For cloud environments, copy the model to the distributed file system.\n",
+ "if on_databricks:\n",
+ " dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
+ " dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/text_model_cleaned.keras\"\n",
+ " shutil.copy(triton_model_path, dbfs_model_path)\n",
+ " triton_model_path = dbfs_model_path\n",
+ "elif on_dataproc:\n",
+ " # GCS is mounted at /mnt/gcs by the init script\n",
+ " models_dir = \"/mnt/gcs/spark-dl/models\"\n",
+ " os.mkdir(models_dir) if not os.path.exists(models_dir) else None\n",
+ " gcs_model_path = models_dir + \"/text_model_cleaned.keras\"\n",
+ " shutil.copy(triton_model_path, gcs_model_path)\n",
+ " triton_model_path = gcs_model_path"
]
},
{
"cell_type": "code",
- "execution_count": 50,
- "id": "a7fb146c-5319-4831-85f7-f2f3c084b042",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 54,
+ "id": "32d0142a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from functools import partial"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "id": "a4d37d33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_server(ports, model_path):\n",
+ " import time\n",
+ " import signal\n",
+ " import numpy as np\n",
+ " import tensorflow as tf\n",
+ " from pytriton.decorators import batch\n",
+ " from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
+ " from pytriton.triton import Triton, TritonConfig\n",
+ " from pyspark import TaskContext\n",
+ " from tensorflow.keras import layers \n",
+ "\n",
+ " \n",
+ " print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
+ " # Enable GPU memory growth\n",
+ " gpus = tf.config.experimental.list_physical_devices('GPU')\n",
+ " if gpus:\n",
+ " try:\n",
+ " for gpu in gpus:\n",
+ " tf.config.experimental.set_memory_growth(gpu, True)\n",
+ " except RuntimeError as e:\n",
+ " print(e)\n",
+ "\n",
+ " def custom_standardization(input_data):\n",
+ " lowercase = tf.strings.lower(input_data)\n",
+ " stripped_html = tf.strings.regex_replace(lowercase, \" \", \" \")\n",
+ " return tf.strings.regex_replace(\n",
+ " stripped_html, \"[%s]\" % re.escape(string.punctuation), \"\"\n",
+ " )\n",
+ "\n",
+ " max_features = 10000\n",
+ " sequence_length = 250\n",
+ "\n",
+ " vectorize_layer = layers.TextVectorization(\n",
+ " standardize=custom_standardization,\n",
+ " max_tokens=max_features,\n",
+ " output_mode=\"int\",\n",
+ " output_sequence_length=sequence_length,\n",
+ " )\n",
+ "\n",
+ " custom_objects = {\"vectorize_layer\": vectorize_layer,\n",
+ " \"custom_standardization\": custom_standardization}\n",
+ "\n",
+ " with tf.keras.utils.custom_object_scope(custom_objects):\n",
+ " model = tf.keras.models.load_model(model_path)\n",
+ "\n",
+ " @batch\n",
+ " def _infer_fn(**inputs):\n",
+ " sentences = inputs[\"text\"]\n",
+ " print(f\"SERVER: Received batch of size {len(sentences)}.\")\n",
+ " decoded_sentences = tf.convert_to_tensor(np.vectorize(lambda x: x.decode('utf-8'))(sentences))\n",
+ " return {\n",
+ " \"preds\": model.predict(decoded_sentences)\n",
+ " }\n",
+ " \n",
+ " workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
+ " triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
+ " with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
+ " triton.bind(\n",
+ " model_name=\"TextModel\",\n",
+ " infer_func=_infer_fn,\n",
+ " inputs=[\n",
+ " Tensor(name=\"text\", dtype=np.bytes_, shape=(-1,)),\n",
+ " ],\n",
+ " outputs=[\n",
+ " Tensor(name=\"preds\", dtype=np.float32, shape=(-1,)),\n",
+ " ],\n",
+ " config=ModelConfig(\n",
+ " max_batch_size=128,\n",
+ " batcher=DynamicBatcher(max_queue_delay_microseconds=5000), # 5ms\n",
+ " ),\n",
+ " strict=True,\n",
+ " )\n",
+ "\n",
+ " def stop_triton(signum, frame):\n",
+ " print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
+ " triton.stop()\n",
+ "\n",
+ " signal.signal(signal.SIGTERM, stop_triton)\n",
+ "\n",
+ " print(\"SERVER: Serving inference\")\n",
+ " triton.serve()\n",
+ "\n",
+ "def start_triton(ports, model_name, model_path):\n",
+ " import socket\n",
+ " from multiprocessing import Process\n",
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " hostname = socket.gethostname()\n",
+ " process = Process(target=triton_server, args=(ports, model_path,))\n",
+ " process.start()\n",
+ "\n",
+ " client = ModelClient(f\"http://localhost:{ports[0]}\", model_name)\n",
+ " patience = 10\n",
+ " while patience > 0:\n",
+ " try:\n",
+ " client.wait_for_model(6)\n",
+ " return [(hostname, process.pid)]\n",
+ " except Exception:\n",
+ " print(\"Waiting for server to be ready...\")\n",
+ " patience -= 1\n",
+ "\n",
+ " emsg = \"Failure: client waited too long for server startup. Check the executor logs for more info.\"\n",
+ " raise TimeoutError(emsg)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d340e231",
+ "metadata": {},
+ "source": [
+ "#### Start Triton servers\n",
+ "\n",
+ "To ensure that only one Triton inference server is started per node, we use stage-level scheduling to delegate each task to a separate GPU. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 56,
+ "id": "35a6eac2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def _use_stage_level_scheduling(spark, rdd):\n",
+ "\n",
+ " if spark.version < \"3.4.0\":\n",
+ " raise Exception(\"Stage-level scheduling is not supported in Spark < 3.4.0\")\n",
+ "\n",
+ " executor_cores = spark.conf.get(\"spark.executor.cores\")\n",
+ " assert executor_cores is not None, \"spark.executor.cores is not set\"\n",
+ " executor_gpus = spark.conf.get(\"spark.executor.resource.gpu.amount\")\n",
+ " assert executor_gpus is not None and int(executor_gpus) <= 1, \"spark.executor.resource.gpu.amount must be set and <= 1\"\n",
+ "\n",
+ " from pyspark.resource.profile import ResourceProfileBuilder\n",
+ " from pyspark.resource.requests import TaskResourceRequests\n",
+ "\n",
+ " spark_plugins = spark.conf.get(\"spark.plugins\", \" \")\n",
+ " assert spark_plugins is not None\n",
+ " spark_rapids_sql_enabled = spark.conf.get(\"spark.rapids.sql.enabled\", \"true\")\n",
+ " assert spark_rapids_sql_enabled is not None\n",
+ "\n",
+ " task_cores = (\n",
+ " int(executor_cores)\n",
+ " if \"com.nvidia.spark.SQLPlugin\" in spark_plugins\n",
+ " and \"true\" == spark_rapids_sql_enabled.lower()\n",
+ " else (int(executor_cores) // 2) + 1\n",
+ " )\n",
+ "\n",
+ " task_gpus = 1.0\n",
+ " treqs = TaskResourceRequests().cpus(task_cores).resource(\"gpu\", task_gpus)\n",
+ " rp = ResourceProfileBuilder().require(treqs).build\n",
+ " print(f\"Reqesting stage-level resources: (cores={task_cores}, gpu={task_gpus})\")\n",
+ "\n",
+ " return rdd.withResources(rp)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bad219c9",
+ "metadata": {},
+ "source": [
+ "**Specify the number of nodes in the cluster.** \n",
+ "Following the README, the example standalone cluster uses 1 node. The example Databricks/Dataproc cluster scripts use 4 nodes by default. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 57,
+ "id": "a01c6198",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Change based on cluster setup\n",
+ "num_nodes = 1 if on_standalone else 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 58,
+ "id": "4d5dc419",
+ "metadata": {},
"outputs": [
{
- "name": "stderr",
+ "name": "stdout",
"output_type": "stream",
"text": [
- " \r"
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
]
- },
- {
- "data": {
- "text/plain": [
- "[True]"
- ]
- },
- "execution_count": 50,
- "metadata": {},
- "output_type": "execute_result"
}
],
"source": [
- "num_executors = 1\n",
- "triton_models_dir = \"{}/models\".format(os.getcwd())\n",
- "text_model_dir = \"{}/text_model_cleaned.keras\".format(os.getcwd())\n",
- "nodeRDD = sc.parallelize(list(range(num_executors)), num_executors)\n",
- "\n",
- "def start_triton(it):\n",
- " import docker\n",
- " import time\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " if containers:\n",
- " print(\">>>> containers: {}\".format([c.short_id for c in containers]))\n",
- " else:\n",
- " container=client.containers.run(\n",
- " \"nvcr.io/nvidia/tritonserver:24.08-py3\", \"tritonserver --model-repository=/models\",\n",
- " detach=True,\n",
- " device_requests=[docker.types.DeviceRequest(device_ids=[\"0\"], capabilities=[['gpu']])],\n",
- " name=\"spark-triton\",\n",
- " network_mode=\"host\",\n",
- " remove=True,\n",
- " shm_size=\"128M\",\n",
- " volumes={\n",
- " triton_models_dir: {\"bind\": \"/models\", \"mode\": \"ro\"},\n",
- " text_model_dir: {\"bind\": \"/text_model_cleaned.keras\", \"mode\": \"ro\"}\n",
- " }\n",
- " )\n",
- " print(\">>>> starting triton: {}\".format(container.short_id))\n",
- "\n",
- " # wait for triton to be running\n",
- " time.sleep(15)\n",
- " client = grpcclient.InferenceServerClient(\"localhost:8001\")\n",
- " ready = False\n",
- " while not ready:\n",
- " try:\n",
- " ready = client.is_server_ready()\n",
- " except Exception as e:\n",
- " time.sleep(5)\n",
- " \n",
- " return [True]\n",
- "\n",
- "nodeRDD.barrier().mapPartitions(start_triton).collect()"
+ "sc = spark.sparkContext\n",
+ "nodeRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "nodeRDD = _use_stage_level_scheduling(spark, nodeRDD)"
]
},
{
"cell_type": "markdown",
- "id": "287873da-6202-4b55-97fb-cda8644b1fee",
+ "id": "0bdba73f",
"metadata": {},
"source": [
- "#### Run inference"
+ "Triton occupies ports for HTTP requests, GRPC requests, and the metrics service."
]
},
{
"cell_type": "code",
- "execution_count": 51,
- "id": "41106a02-236e-4cb3-ac51-76aa64b663c2",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 59,
+ "id": "013205e3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def find_ports():\n",
+ " import psutil\n",
+ " \n",
+ " ports = []\n",
+ " conns = {conn.laddr.port for conn in psutil.net_connections(kind=\"inet\")}\n",
+ " i = 7000\n",
+ " while len(ports) < 3:\n",
+ " if i not in conns:\n",
+ " ports.append(i)\n",
+ " i += 1\n",
+ " \n",
+ " return ports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 60,
+ "id": "7fa58218",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+----------------------------------------------------------------------------------------------------+\n",
- "| lines|\n",
- "+----------------------------------------------------------------------------------------------------+\n",
- "|i do not understand at all why this movie received such good grades from critics - - i've seen te...|\n",
- "| I am a big fan of The ABC Movies of the Week genre|\n",
- "| Strangeland is a terrible horror/technological thriller|\n",
- "| Sex,Drugs,Rock & Roll is without a doubt the worst product of Western Civilization|\n",
- "| Not to be mistaken as the highly touted Samuel L|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second one, this incredibly...|\n",
- "| No idea how this is rated as high as it is (5|\n",
- "|When I saw this in the cinema, I remember wincing at the bad acting about a minute or two into th...|\n",
- "| I was shocked at how bad it was and unable to turn away from the disaster|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial sequel to 'The French...|\n",
- "|Greetings All,
Isn't it amazing the power that films have on you after the 1st viewing...|\n",
- "| I'm sorry but this guy is not funny|\n",
- "|This movie is so dull I spent half of it on IMDb while it was open in another tab on Netflix tryi...|\n",
- "| OK, lets start with the best|\n",
- "|This show had a promising start as sort of the opposite of 'Oceans 11' but has developed into a s...|\n",
- "|The 3rd in the series finds Paul Kersey (Bronson) turning vigilante to get revenge on the thugs t...|\n",
- "| I'm not sure I've ever seen a film as bad as this|\n",
- "| Steven Seagal has made a really dull, bad and boring movie|\n",
- "| You have to acknowledge Cimino's contribution to cinema|\n",
- "|***SPOILER ALERT*** Disjointed and confusing arson drama that has to do with a sinister plan to b...|\n",
- "+----------------------------------------------------------------------------------------------------+\n",
- "only showing top 20 rows\n",
- "\n"
+ "Using ports [7000, 7001, 7002]\n"
]
}
],
"source": [
- "from datasets import load_dataset\n",
- "\n",
- "# load IMDB reviews (test) dataset\n",
- "data = load_dataset(\"imdb\", split=\"test\")\n",
- "lines = []\n",
- "for example in data:\n",
- " lines.append([example[\"text\"].split(\".\")[0]])\n",
- "\n",
- "df = spark.createDataFrame(lines, ['lines']).repartition(10)\n",
- "df.show(truncate=100)"
+ "model_name = \"TextModel\"\n",
+ "ports = find_ports()\n",
+ "assert len(ports) == 3\n",
+ "print(f\"Using ports {ports}\")"
]
},
{
"cell_type": "code",
- "execution_count": 52,
- "id": "8b763167-7f50-4278-9bc9-6c3433b62294",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 61,
+ "id": "bdcf9187",
+ "metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "['lines']"
- ]
- },
- "execution_count": 52,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 19:> (0 + 1) / 1]\r"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Triton Server PIDs:\n",
+ " {\n",
+ " \"cb4ae00-lcedt\": 2897388\n",
+ "}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " \r"
+ ]
}
],
"source": [
- "columns = df.columns\n",
- "columns"
+ "pids = nodeRDD.barrier().mapPartitions(lambda _: start_triton(ports, model_name, triton_model_path)).collectAsMap()\n",
+ "print(\"Triton Server PIDs:\\n\", json.dumps(pids, indent=4))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e1477f4b",
+ "metadata": {},
+ "source": [
+ "#### Define client function"
]
},
{
"cell_type": "code",
- "execution_count": 53,
- "id": "29b0cc0d-c480-4e4a-bd41-207dc314cba5",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "execution_count": 62,
+ "id": "d590cd25",
+ "metadata": {},
"outputs": [],
"source": [
- "def triton_fn(triton_uri, model_name):\n",
+ "url = f\"http://localhost:{ports[0]}\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 63,
+ "id": "0ad47438",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def triton_fn(url, model_name):\n",
" import numpy as np\n",
- " import tritonclient.grpc as grpcclient\n",
- " \n",
- " np_types = {\n",
- " \"BOOL\": np.dtype(np.bool_),\n",
- " \"INT8\": np.dtype(np.int8),\n",
- " \"INT16\": np.dtype(np.int16),\n",
- " \"INT32\": np.dtype(np.int32),\n",
- " \"INT64\": np.dtype(np.int64),\n",
- " \"FP16\": np.dtype(np.float16),\n",
- " \"FP32\": np.dtype(np.float32),\n",
- " \"FP64\": np.dtype(np.float64),\n",
- " \"FP64\": np.dtype(np.double),\n",
- " \"BYTES\": np.dtype(object)\n",
- " }\n",
- "\n",
- " client = grpcclient.InferenceServerClient(triton_uri)\n",
- " model_meta = client.get_model_metadata(model_name)\n",
- " \n",
- " def predict(inputs):\n",
- " if isinstance(inputs, np.ndarray):\n",
- " # single ndarray input\n",
- " request = [grpcclient.InferInput(model_meta.inputs[0].name, inputs.shape, model_meta.inputs[0].datatype)]\n",
- " request[0].set_data_from_numpy(inputs.astype(np_types[model_meta.inputs[0].datatype]))\n",
- " else:\n",
- " # dict of multiple ndarray inputs\n",
- " request = [grpcclient.InferInput(i.name, inputs[i.name].shape, i.datatype) for i in model_meta.inputs]\n",
- " for i in request:\n",
- " i.set_data_from_numpy(inputs[i.name()].astype(np_types[i.datatype()]))\n",
- " \n",
- " response = client.infer(model_name, inputs=request)\n",
- " \n",
- " if len(model_meta.outputs) > 1:\n",
- " # return dictionary of numpy arrays\n",
- " return {o.name: response.as_numpy(o.name) for o in model_meta.outputs}\n",
- " else:\n",
- " # return single numpy array\n",
- " return response.as_numpy(model_meta.outputs[0].name)\n",
- " \n",
- " return predict"
+ " from pytriton.client import ModelClient\n",
+ "\n",
+ " print(f\"CLIENT: Connecting to {model_name} at {url}\")\n",
+ "\n",
+ " def infer_batch(inputs):\n",
+ " with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
+ " encoded_inputs = np.vectorize(lambda x: x.encode(\"utf-8\"))(inputs).astype(np.bytes_)\n",
+ " encoded_inputs = np.expand_dims(encoded_inputs, axis=1)\n",
+ " result_data = client.infer_batch(encoded_inputs)\n",
+ " \n",
+ " return result_data[\"preds\"]\n",
+ " \n",
+ " return infer_batch"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "91974885",
+ "metadata": {},
+ "source": [
+ "#### Load and preprocess DataFrame"
]
},
{
"cell_type": "code",
- "execution_count": 54,
+ "execution_count": 64,
+ "id": "41106a02-236e-4cb3-ac51-76aa64b663c2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df = spark.read.parquet(data_path).limit(512).repartition(8)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e851870b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "25/01/06 22:02:44 WARN CacheManager: Asked to cache already cached data.\n"
+ ]
+ }
+ ],
+ "source": [
+ "input_df = df.select(preprocess(col(\"text\")).alias(\"lines\")).cache()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 66,
"id": "8e06d33f-5cef-4a48-afc3-5d468f8ec2b4",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [],
"source": [
- "from functools import partial\n",
- "\n",
- "classify = predict_batch_udf(partial(triton_fn, triton_uri=\"localhost:8001\", model_name=\"text_classification\"),\n",
- " input_tensor_shapes=[[1]],\n",
+ "classify = predict_batch_udf(partial(triton_fn, url=url, model_name=\"TextModel\"),\n",
" return_type=FloatType(),\n",
- " batch_size=2048)"
+ " batch_size=64)"
]
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 67,
"id": "d89e74ad-e551-4bfa-ad08-98725878630a",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "[Stage 23:> (0 + 8) / 8]\r"
+ ]
+ },
{
"name": "stdout",
"output_type": "stream",
"text": [
- "+--------------------------------------------------------------------------------+----------+\n",
- "| lines| preds|\n",
- "+--------------------------------------------------------------------------------+----------+\n",
- "|i do not understand at all why this movie received such good grades from crit...| 0.5380144|\n",
- "| I am a big fan of The ABC Movies of the Week genre|0.59806347|\n",
- "| Strangeland is a terrible horror/technological thriller|0.54900867|\n",
- "|Sex,Drugs,Rock & Roll is without a doubt the worst product of Western Civiliz...|0.56048334|\n",
- "| Not to be mistaken as the highly touted Samuel L|0.56276447|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second ...| 0.5571853|\n",
- "| No idea how this is rated as high as it is (5| 0.5637812|\n",
- "|When I saw this in the cinema, I remember wincing at the bad acting about a m...|0.66255826|\n",
- "| I was shocked at how bad it was and unable to turn away from the disaster| 0.5871666|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial s...| 0.5578672|\n",
- "|Greetings All,
Isn't it amazing the power that films have on you a...|0.56385136|\n",
- "| I'm sorry but this guy is not funny| 0.5634932|\n",
- "|This movie is so dull I spent half of it on IMDb while it was open in another...|0.58991694|\n",
- "| OK, lets start with the best| 0.5795415|\n",
- "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|0.57494473|\n",
- "|The 3rd in the series finds Paul Kersey (Bronson) turning vigilante to get re...| 0.6133918|\n",
- "| I'm not sure I've ever seen a film as bad as this| 0.5336116|\n",
- "| Steven Seagal has made a really dull, bad and boring movie|0.55780387|\n",
- "| You have to acknowledge Cimino's contribution to cinema| 0.5763774|\n",
- "|***SPOILER ALERT*** Disjointed and confusing arson drama that has to do with ...|0.56471467|\n",
- "+--------------------------------------------------------------------------------+----------+\n",
- "only showing top 20 rows\n",
- "\n",
- "CPU times: user 2.49 ms, sys: 1.47 ms, total: 3.96 ms\n",
- "Wall time: 916 ms\n"
+ "CPU times: user 10.4 ms, sys: 7 ms, total: 17.4 ms\n",
+ "Wall time: 2.53 s\n"
]
},
{
@@ -1759,18 +2087,57 @@
],
"source": [
"%%time\n",
- "df.withColumn(\"preds\", classify(struct(*columns))).show(truncate=80)"
+ "predictions = input_df.withColumn(\"preds\", classify(struct(\"lines\")))\n",
+ "results = predictions.collect()"
]
},
{
"cell_type": "code",
- "execution_count": 56,
+ "execution_count": 68,
"id": "b4fa7fc9-341c-49a6-9af2-e316f2355d67",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 2.59 ms, sys: 631 μs, total: 3.22 ms\n",
+ "Wall time: 214 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions = input_df.withColumn(\"preds\", classify(\"lines\"))\n",
+ "results = predictions.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 69,
+ "id": "564f999b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CPU times: user 288 μs, sys: 3.66 ms, total: 3.95 ms\n",
+ "Wall time: 245 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "predictions = input_df.withColumn(\"preds\", classify(col(\"lines\")))\n",
+ "results = predictions.collect()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 70,
+ "id": "9222e8a9",
+ "metadata": {},
"outputs": [
{
"name": "stdout",
@@ -1779,37 +2146,34 @@
"+--------------------------------------------------------------------------------+----------+\n",
"| lines| preds|\n",
"+--------------------------------------------------------------------------------+----------+\n",
- "|i do not understand at all why this movie received such good grades from crit...| 0.5380144|\n",
- "| I am a big fan of The ABC Movies of the Week genre|0.59806347|\n",
- "| Strangeland is a terrible horror/technological thriller|0.54900867|\n",
- "|Sex,Drugs,Rock & Roll is without a doubt the worst product of Western Civiliz...|0.56048334|\n",
- "| Not to be mistaken as the highly touted Samuel L|0.56276447|\n",
- "|Following the pleasingly atmospheric original and the amusingly silly second ...| 0.5571853|\n",
- "| No idea how this is rated as high as it is (5| 0.5637812|\n",
- "|When I saw this in the cinema, I remember wincing at the bad acting about a m...|0.66255826|\n",
- "| I was shocked at how bad it was and unable to turn away from the disaster| 0.5871666|\n",
- "|I don't know if this exceptionally dull movie was intended as an unofficial s...| 0.5578672|\n",
- "|Greetings All,
Isn't it amazing the power that films have on you a...|0.56385136|\n",
- "| I'm sorry but this guy is not funny| 0.5634932|\n",
- "|This movie is so dull I spent half of it on IMDb while it was open in another...|0.58991694|\n",
- "| OK, lets start with the best| 0.5795415|\n",
- "|This show had a promising start as sort of the opposite of 'Oceans 11' but ha...|0.57494473|\n",
- "|The 3rd in the series finds Paul Kersey (Bronson) turning vigilante to get re...| 0.6133918|\n",
- "| I'm not sure I've ever seen a film as bad as this| 0.5336116|\n",
- "| Steven Seagal has made a really dull, bad and boring movie|0.55780387|\n",
- "| You have to acknowledge Cimino's contribution to cinema| 0.5763774|\n",
- "|***SPOILER ALERT*** Disjointed and confusing arson drama that has to do with ...|0.56471467|\n",
+ "|The only reason I'm even giving this movie a 4 is because it was made in to a...| 0.5441438|\n",
+ "|Awkward disaster mishmash has a team of scavengers coming across the overturn...|0.58016133|\n",
+ "|Here is a fantastic concept for a film - a series of meteors crash into a sma...|0.55131954|\n",
+ "| I walked out of the cinema having suffered this film after 30 mins| 0.542057|\n",
+ "|A wildly uneven film where the major problem is the uneasy mix of comedy and ...| 0.5196002|\n",
+ "|Leonard Rossiter and Frances de la Tour carry this film, not without a strugg...|0.53112733|\n",
+ "| A good cast| 0.5486873|\n",
+ "|Yet again, I appear to be the only person on planet Earth who is capable of c...| 0.5343111|\n",
+ "|As a serious horror fan, I get that certain marketing ploys are used to sell ...| 0.5497148|\n",
+ "|Upon writing this review I have difficulty trying to think of what to write a...| 0.5581456|\n",
+ "| Simply awful| 0.5701754|\n",
+ "|I am a fan of Ed Harris' work and I really had high expectations about this film| 0.5510578|\n",
+ "| Well|0.55721515|\n",
+ "| This is a new approach to comedy|0.56038314|\n",
+ "| It's been mentioned by others the inane dialogue in this series and I agree| 0.5451202|\n",
+ "|One of the most boring movies I've ever had to sit through, it's completely f...|0.56161135|\n",
+ "|This movie was playing on Lifetime Movie Network last month and I decided to ...| 0.5555233|\n",
+ "| 1983's \"Frightmare\" is an odd little film| 0.5363368|\n",
+ "| 'Felony' is a B-movie|0.55682427|\n",
+ "| This movie defines the word \"confused\"| 0.5630136|\n",
"+--------------------------------------------------------------------------------+----------+\n",
"only showing top 20 rows\n",
- "\n",
- "CPU times: user 571 μs, sys: 2.22 ms, total: 2.79 ms\n",
- "Wall time: 528 ms\n"
+ "\n"
]
}
],
"source": [
- "%%time\n",
- "df.withColumn(\"preds\", classify(*columns)).show(truncate=80)"
+ "predictions.show(truncate=80)"
]
},
{
@@ -1824,14 +2188,17 @@
},
{
"cell_type": "code",
- "execution_count": 57,
+ "execution_count": 71,
"id": "a71ac9b6-47a2-4306-bc40-9ce7b4e968ec",
- "metadata": {
- "tags": [
- "TRITON"
- ]
- },
+ "metadata": {},
"outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Reqesting stage-level resources: (cores=5, gpu=1.0)\n"
+ ]
+ },
{
"name": "stderr",
"output_type": "stream",
@@ -1845,31 +2212,39 @@
"[True]"
]
},
- "execution_count": 57,
+ "execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "def stop_triton(it):\n",
- " import docker\n",
- " import time\n",
+ "def stop_triton(pids):\n",
+ " import os\n",
+ " import socket\n",
+ " import signal\n",
+ " import time \n",
" \n",
- " client=docker.from_env()\n",
- " containers=client.containers.list(filters={\"name\": \"spark-triton\"})\n",
- " print(\">>>> stopping containers: {}\".format([c.short_id for c in containers]))\n",
- " if containers:\n",
- " container=containers[0]\n",
- " container.stop(timeout=120)\n",
+ " hostname = socket.gethostname()\n",
+ " pid = pids.get(hostname, None)\n",
+ " assert pid is not None, f\"Could not find pid for {hostname}\"\n",
+ " \n",
+ " for _ in range(5):\n",
+ " try:\n",
+ " os.kill(pid, signal.SIGTERM)\n",
+ " except OSError:\n",
+ " return [True]\n",
+ " time.sleep(5)\n",
"\n",
- " return [True]\n",
+ " return [False]\n",
"\n",
- "nodeRDD.barrier().mapPartitions(stop_triton).collect()"
+ "shutdownRDD = sc.parallelize(list(range(num_nodes)), num_nodes)\n",
+ "shutdownRDD = _use_stage_level_scheduling(spark, shutdownRDD)\n",
+ "shutdownRDD.barrier().mapPartitions(lambda _: stop_triton(pids)).collect()"
]
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": 73,
"id": "54a90574-7cbb-487b-b7a8-dcda0e6e301f",
"metadata": {},
"outputs": [],