diff --git a/application/pom.xml b/application/pom.xml index 6159e557..9394bc67 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -227,6 +227,11 @@ dh-runtime-python ${revision} + + it.smartcommunitylabdhub + dh-runtime-model-serve + ${revision} + it.smartcommunitylabdhub dh-console diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/HuggingFaceModelSpec.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/HuggingFaceModelSpec.java new file mode 100644 index 00000000..cf663f4a --- /dev/null +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/HuggingFaceModelSpec.java @@ -0,0 +1,11 @@ + +package it.smartcommunitylabdhub.core.models.specs.model; + + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; + + +@SpecType(kind = "huggingface", entity = EntityName.MODEL) +public class HuggingFaceModelSpec extends it.smartcommunitylabdhub.commons.models.entities.model.HuggingFaceModelSpec { +} diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/MlflowModelSpec.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/MlflowModelSpec.java new file mode 100644 index 00000000..3ff85511 --- /dev/null +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/MlflowModelSpec.java @@ -0,0 +1,9 @@ +package it.smartcommunitylabdhub.core.models.specs.model; + + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +@SpecType(kind = "mlflow", entity = EntityName.MODEL) +public class MlflowModelSpec extends it.smartcommunitylabdhub.commons.models.entities.model.MlflowModelSpec { + +} \ No newline at end of file diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/ModelSpec.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/ModelSpec.java index 733d3e47..000c05d5 100644 --- a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/ModelSpec.java +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/ModelSpec.java @@ -1,37 +1,7 @@ package it.smartcommunitylabdhub.core.models.specs.model; -import com.fasterxml.jackson.annotation.JsonProperty; import it.smartcommunitylabdhub.commons.annotations.common.SpecType; -import it.smartcommunitylabdhub.commons.models.entities.model.ModelBaseSpec; import it.smartcommunitylabdhub.commons.models.enums.EntityName; -import java.io.Serializable; -import java.util.LinkedHashMap; -import java.util.Map; -import lombok.Getter; -import lombok.Setter; - -@Getter -@Setter @SpecType(kind = "model", entity = EntityName.MODEL) -public class ModelSpec extends ModelBaseSpec { - - @JsonProperty("base_model") - private String baseModel; - - @JsonProperty("parameters") - private Map parameters = new LinkedHashMap<>(); - - @JsonProperty("metrics") - private Map metrics = new LinkedHashMap<>(); - - @Override - public void configure(Map data) { - super.configure(data); - - ModelSpec spec = mapper.convertValue(data, ModelSpec.class); - - this.baseModel = spec.getBaseModel(); - this.parameters = spec.getParameters(); - this.metrics = spec.getMetrics(); - } +public class ModelSpec extends it.smartcommunitylabdhub.commons.models.entities.model.ModelBaseSpec { } diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/SKLearnModelSpec.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/SKLearnModelSpec.java new file mode 100644 index 00000000..820a6ba5 --- /dev/null +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/SKLearnModelSpec.java @@ -0,0 +1,9 @@ +package it.smartcommunitylabdhub.core.models.specs.model; + + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; + +@SpecType(kind = "sklearn", entity = EntityName.MODEL) +public class SKLearnModelSpec extends it.smartcommunitylabdhub.commons.models.entities.model.SKLearnModelSpec { +} \ No newline at end of file diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Dataset.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Dataset.java new file mode 100644 index 00000000..addd12cb --- /dev/null +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Dataset.java @@ -0,0 +1,18 @@ +package it.smartcommunitylabdhub.core.models.specs.model.mlflow; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class Dataset { + + private String name; + private String digest; + private String profile; + private String schema; + private String source; + @JsonProperty("source_type") + private String sourceType; +} \ No newline at end of file diff --git a/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Signature.java b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Signature.java new file mode 100644 index 00000000..9575227a --- /dev/null +++ b/application/src/main/java/it/smartcommunitylabdhub/core/models/specs/model/mlflow/Signature.java @@ -0,0 +1,13 @@ +package it.smartcommunitylabdhub.core.models.specs.model.mlflow; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class Signature { + + private String inputs; + private String outputs; + private String params; +} \ No newline at end of file diff --git a/application/src/main/resources/application.yml b/application/src/main/resources/application.yml index 79b53857..d9dfdc31 100644 --- a/application/src/main/resources/application.yml +++ b/application/src/main/resources/application.yml @@ -71,7 +71,12 @@ runtime: PYTHON3_9: ${RUNTIME_PYTHON_IMAGE_3_9:ghcr.io/scc-digitalhub/digitalhub-serverless/python-runtime:3.9} PYTHON3_10: ${RUNTIME_PYTHON_IMAGE_3_10:ghcr.io/scc-digitalhub/digitalhub-serverless/python-runtime:3.10} command: /usr/local/bin/processor - + sklearnserve: + image: ${RUNTIME_SKLEARN_SERVE_IMAGE:kserve/sklearnserver:latest} + mlflowserve: + image: ${RUNTIME_MLFLOW_SERVE_IMAGE:seldonio/mlserver:1.6.0-mlflow} + huggingfaceserve: + image: ${RUNTIME_HUGGINGFACE_SERVE_IMAGE:kserve/huggingfaceserver:latest} # Spring Docs springdoc: diff --git a/core-builder-tool/builder-tool.sh b/core-builder-tool/builder-tool.sh index 66bd617d..a43249c1 100755 --- a/core-builder-tool/builder-tool.sh +++ b/core-builder-tool/builder-tool.sh @@ -213,6 +213,13 @@ if [ -f "$source_dir/context-refs.txt" ]; then curl -o "$destination_dir/$destination" -L "$source" unzip "$destination_dir/$destination" -d "$destination_dir" ;; + "s3") # for now accept a folder/path + mc alias set $minio $S3_ENDPOINT_URL $AWS_ACCESS_KEY_ID $AWS_SECRET_ACCESS_KEY + echo "Protocol: $protocol" + echo "Downloading $minio/$rebuilt_url" + echo "to $destination_dir/$destination" + mc cp --recursive "$minio/$rebuilt_url" "$destination_dir/$destination" + ;; # Add more cases for other protocols as needed *) echo "Unknown protocol: $protocol" diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/HuggingFaceModelSpec.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/HuggingFaceModelSpec.java new file mode 100644 index 00000000..2fc31eb8 --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/HuggingFaceModelSpec.java @@ -0,0 +1,32 @@ +package it.smartcommunitylabdhub.commons.models.entities.model; + +import lombok.Getter; +import lombok.Setter; + +import java.io.Serializable; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; + +@Getter +@Setter +public class HuggingFaceModelSpec extends ModelSpec { + + //Huggingface model id + @JsonProperty("model_id") + private String modelId; + + //Huggingface model revision + @JsonProperty("model_revision") + private String modelRevision; + + @Override + public void configure(Map data) { + super.configure(data); + + HuggingFaceModelSpec spec = mapper.convertValue(data, HuggingFaceModelSpec.class); + + this.modelId = spec.getModelId(); + this.modelRevision = spec.getModelRevision(); + } +} diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/MlflowModelSpec.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/MlflowModelSpec.java new file mode 100644 index 00000000..50240307 --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/MlflowModelSpec.java @@ -0,0 +1,40 @@ +package it.smartcommunitylabdhub.commons.models.entities.model; + + +import lombok.Getter; +import lombok.Setter; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import it.smartcommunitylabdhub.commons.models.entities.model.mlflow.Dataset; +import it.smartcommunitylabdhub.commons.models.entities.model.mlflow.Signature; + +@Getter +@Setter +public class MlflowModelSpec extends ModelSpec { + + private String flavor; + + @JsonProperty("model_config") + private Map modelConfig; + + @JsonProperty("input_datasets") + private List inputDatasets; + + private Signature signature; + + @Override + public void configure(Map data) { + super.configure(data); + + MlflowModelSpec spec = mapper.convertValue(data, MlflowModelSpec.class); + this.flavor = spec.getFlavor(); + this.signature = spec.getSignature(); + this.inputDatasets = spec.getInputDatasets(); + this.modelConfig = spec.getModelConfig(); + } +} \ No newline at end of file diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/ModelSpec.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/ModelSpec.java new file mode 100644 index 00000000..71f7f7cb --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/ModelSpec.java @@ -0,0 +1,33 @@ +package it.smartcommunitylabdhub.commons.models.entities.model; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.io.Serializable; +import java.util.LinkedHashMap; +import java.util.Map; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class ModelSpec extends ModelBaseSpec { + + @JsonProperty("base_model") + private String baseModel; + + @JsonProperty("parameters") + private Map parameters = new LinkedHashMap<>(); + + @JsonProperty("metrics") + private Map metrics = new LinkedHashMap<>(); + + @Override + public void configure(Map data) { + super.configure(data); + + ModelSpec spec = mapper.convertValue(data, ModelSpec.class); + + this.baseModel = spec.getBaseModel(); + this.parameters = spec.getParameters(); + this.metrics = spec.getMetrics(); + } +} diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/SKLearnModelSpec.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/SKLearnModelSpec.java new file mode 100644 index 00000000..0172d77a --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/SKLearnModelSpec.java @@ -0,0 +1,10 @@ +package it.smartcommunitylabdhub.commons.models.entities.model; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class SKLearnModelSpec extends ModelSpec { + +} \ No newline at end of file diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Dataset.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Dataset.java new file mode 100644 index 00000000..1ee1a187 --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Dataset.java @@ -0,0 +1,18 @@ +package it.smartcommunitylabdhub.commons.models.entities.model.mlflow; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class Dataset { + + private String name; + private String digest; + private String profile; + private String schema; + private String source; + @JsonProperty("source_type") + private String sourceType; +} \ No newline at end of file diff --git a/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Signature.java b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Signature.java new file mode 100644 index 00000000..f5572e1c --- /dev/null +++ b/modules/commons/src/main/java/it/smartcommunitylabdhub/commons/models/entities/model/mlflow/Signature.java @@ -0,0 +1,13 @@ +package it.smartcommunitylabdhub.commons.models.entities.model.mlflow; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class Signature { + + private String inputs; + private String outputs; + private String params; +} \ No newline at end of file diff --git a/modules/framework-k8s/src/main/java/it/smartcommunitylabdhub/framework/k8s/infrastructure/k8s/K8sServeFramework.java b/modules/framework-k8s/src/main/java/it/smartcommunitylabdhub/framework/k8s/infrastructure/k8s/K8sServeFramework.java index 1287ed80..24ce427f 100644 --- a/modules/framework-k8s/src/main/java/it/smartcommunitylabdhub/framework/k8s/infrastructure/k8s/K8sServeFramework.java +++ b/modules/framework-k8s/src/main/java/it/smartcommunitylabdhub/framework/k8s/infrastructure/k8s/K8sServeFramework.java @@ -304,7 +304,7 @@ public V1Service build(K8sServeRunnable runnable) throws K8sFrameworkException { .stream() .filter(p -> p.port() != null && p.targetPort() != null) .map(p -> - new V1ServicePort().port(p.port()).targetPort(new IntOrString(p.targetPort())).protocol("TCP") + new V1ServicePort().port(p.port()).targetPort(new IntOrString(p.targetPort())).protocol("TCP").name("port" + p.port()) ) .collect(Collectors.toList()) ) diff --git a/modules/runtime-model-serve/.flattened-pom.xml b/modules/runtime-model-serve/.flattened-pom.xml new file mode 100644 index 00000000..61d711a5 --- /dev/null +++ b/modules/runtime-model-serve/.flattened-pom.xml @@ -0,0 +1,77 @@ + + + 4.0.0 + it.smartcommunitylabdhub + dh-runtime-model-serve + 0.7.0-SNAPSHOT + + + it.smartcommunitylabdhub + dh-commons + 0.7.0-SNAPSHOT + compile + + + it.smartcommunitylabdhub + dh-framework-k8s + 0.7.0-SNAPSHOT + compile + + + it.smartcommunitylabdhub + dh-framework-kaniko + 0.7.0-SNAPSHOT + compile + + + org.projectlombok + lombok + 1.18.34 + compile + true + + + com.fasterxml.jackson.core + jackson-core + 2.16.2 + compile + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + 2.16.2 + compile + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + 2.16.2 + compile + + + com.fasterxml.jackson.module + jackson-module-jsonSchema + 2.16.2 + compile + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + 2.16.2 + compile + + + org.slf4j + slf4j-api + 2.0.9 + compile + + + org.slf4j + log4j-over-slf4j + 2.0.9 + compile + + + diff --git a/modules/runtime-model-serve/pom.xml b/modules/runtime-model-serve/pom.xml new file mode 100644 index 00000000..6409b8d1 --- /dev/null +++ b/modules/runtime-model-serve/pom.xml @@ -0,0 +1,78 @@ + + + 4.0.0 + + it.smartcommunitylabdhub + digitalhub-core + ${revision} + ../../ + + it.smartcommunitylabdhub + dh-runtime-model-serve + runtime-python + DHCore runtime-model-serve + + + + it.smartcommunitylabdhub + dh-commons + ${revision} + + + it.smartcommunitylabdhub + dh-framework-k8s + ${revision} + + + it.smartcommunitylabdhub + dh-framework-kaniko + ${revision} + + + + org.projectlombok + lombok + ${lombok.version} + compile + true + + + com.fasterxml.jackson.core + jackson-core + ${jackson.version} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-cbor + ${jackson.version} + + + com.fasterxml.jackson.dataformat + jackson-dataformat-yaml + ${jackson.version} + + + com.fasterxml.jackson.module + jackson-module-jsonSchema + ${jackson.version} + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + ${jackson.version} + + + org.slf4j + slf4j-api + ${slf4j.version} + + + org.slf4j + log4j-over-slf4j + ${slf4j.version} + + + + \ No newline at end of file diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRunner.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRunner.java new file mode 100644 index 00000000..a29f48be --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRunner.java @@ -0,0 +1,222 @@ +package it.smartcommunitylabdhub.runtime.huggingface; + +import it.smartcommunitylabdhub.commons.accessors.spec.TaskSpecAccessor; +import it.smartcommunitylabdhub.commons.infrastructure.Runner; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.enums.State; +import it.smartcommunitylabdhub.commons.models.utils.TaskUtils; +import it.smartcommunitylabdhub.framework.k8s.kubernetes.K8sBuilderHelper; +import it.smartcommunitylabdhub.framework.k8s.model.ContextRef; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreEnv; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreLabel; +import it.smartcommunitylabdhub.framework.k8s.objects.CorePort; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreServiceType; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sServeRunnable; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeRunSpec; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeTaskSpec; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +public class HuggingfaceServeRunner implements Runner { + + private static final int HTTP_PORT = 8080; + private static final int GRPC_PORT = 8081; + + private final String image; + private final HuggingfaceServeFunctionSpec functionSpec; + private final Map> groupedSecrets; + + private final K8sBuilderHelper k8sBuilderHelper; + + public HuggingfaceServeRunner( + String image, + HuggingfaceServeFunctionSpec functionSpec, + Map> groupedSecrets, + K8sBuilderHelper k8sBuilderHelper + ) { + this.image = image; + this.functionSpec = functionSpec; + this.groupedSecrets = groupedSecrets; + this.k8sBuilderHelper = k8sBuilderHelper; + } + + @Override + public K8sRunnable produce(Run run) { + HuggingfaceServeRunSpec runSpec = HuggingfaceServeRunSpec.with(run.getSpec()); + HuggingfaceServeTaskSpec taskSpec = runSpec.getTaskServeSpec(); + TaskSpecAccessor taskAccessor = TaskUtils.parseFunction(taskSpec.getFunction()); + + List coreEnvList = new ArrayList<>( + List.of(new CoreEnv("PROJECT_NAME", run.getProject()), new CoreEnv("RUN_ID", run.getId())) + ); + + Optional.ofNullable(taskSpec.getEnvs()).ifPresent(coreEnvList::addAll); + + UriComponents uri = UriComponentsBuilder.fromUriString(functionSpec.getPath()).build(); + + //read source and build context + List contextRefs = null; + + List args = new ArrayList<>( + List.of( + "-m", + "huggingfaceserver", + "--model_name", + StringUtils.hasText(functionSpec.getModelName()) ? functionSpec.getModelName() : "model", + "--protocol", + "v2", + "--enable_docs_url", + "true" + ) + ); + + // model dir or model id + if (!"huggingface".equals(uri.getScheme())) { + args.add("--model_dir"); + args.add("/shared/model"); + contextRefs = + Collections.singletonList( + ContextRef + .builder() + .source(functionSpec.getPath()) + .protocol(uri.getScheme()) + .destination("model") + .build() + ); + } else { + String mdlId = uri.getHost() + uri.getPath(); + String revision = null; + if (mdlId.contains(":")) { + String[] parts = mdlId.split(":"); + mdlId = parts[0]; + revision = parts[1]; + } + args.add("--model_id"); + args.add(mdlId); + if (revision != null) { + args.add("--model_revision"); + args.add(revision); + } + contextRefs = Collections.emptyList(); + } + // tokenizer revision + if (StringUtils.hasText(taskSpec.getTokenizerRevision())) { + args.add("--tokenizer_revision"); + args.add(taskSpec.getTokenizerRevision()); + } + // max length + if (taskSpec.getMaxLength() != null) { + args.add("--max_length"); + args.add(taskSpec.getMaxLength().toString()); + } + // disable_lower_case + if (taskSpec.getDisableLowerCase() != null) { + args.add("--disable_lower_case"); + args.add(taskSpec.getDisableLowerCase().toString()); + } + // disable_special_tokens + if (taskSpec.getDisableSpecialTokens() != null) { + args.add("--disable_special_tokens"); + args.add(taskSpec.getDisableSpecialTokens().toString()); + } + // trust_remote_code + if (taskSpec.getTrustRemoteCode() != null) { + args.add("--trust_remote_code"); + args.add(taskSpec.getTrustRemoteCode().toString()); + } else { + args.add("--trust_remote_code"); + args.add("true"); + } + // tensor_input_names + if (taskSpec.getTensorInputNames() != null) { + args.add("--tensor_input_names"); + args.add(StringUtils.collectionToCommaDelimitedString(taskSpec.getTensorInputNames())); + } + // task + if (taskSpec.getHuggingfaceTask() != null) { + args.add("--task"); + args.add(taskSpec.getHuggingfaceTask().getTask()); + } + // backend + if (taskSpec.getBackend() != null) { + args.add("--backend"); + args.add(taskSpec.getBackend().getBackend()); + } + // return_token_type_ids + if (taskSpec.getReturnTokenTypeIds() != null) { + args.add("--return_token_type_ids"); + args.add(taskSpec.getReturnTokenTypeIds().toString()); + } + // return_probabilities + if (taskSpec.getReturnProbabilities() != null) { + args.add("--return_probabilities"); + args.add(taskSpec.getReturnProbabilities().toString()); + } + // disable_log_requests + if (taskSpec.getDisableLogRequests() != null) { + args.add("--disable_log_requests"); + args.add(taskSpec.getDisableLogRequests().toString()); + } + // max_log_len + if (taskSpec.getMaxLogLen() != null) { + args.add("--max_log_len"); + args.add(taskSpec.getMaxLogLen().toString()); + } + // dtype + if (taskSpec.getDtype() != null) { + args.add("--dtype"); + args.add(taskSpec.getDtype().getDType()); + } + + CorePort servicePort = new CorePort(HTTP_PORT, HTTP_PORT); + CorePort grpcPort = new CorePort(GRPC_PORT, GRPC_PORT); + + String img = StringUtils.hasText(functionSpec.getImage()) ? functionSpec.getImage() : image; + + //build runnable + K8sRunnable k8sServeRunnable = K8sServeRunnable + .builder() + .runtime(HuggingfaceServeRuntime.RUNTIME) + .task(HuggingfaceServeTaskSpec.KIND) + .state(State.READY.name()) + .labels( + k8sBuilderHelper != null + ? List.of(new CoreLabel(k8sBuilderHelper.getLabelName("function"), taskAccessor.getFunction())) + : null + ) + //base + .image(img) + .command("python") + .args(args.toArray(new String[0])) + .contextRefs(contextRefs) + .envs(coreEnvList) + .secrets(groupedSecrets) + .resources(taskSpec.getResources()) + .volumes(taskSpec.getVolumes()) + .nodeSelector(taskSpec.getNodeSelector()) + .affinity(taskSpec.getAffinity()) + .tolerations(taskSpec.getTolerations()) + .runtimeClass(taskSpec.getRuntimeClass()) + .priorityClass(taskSpec.getPriorityClass()) + .template(taskSpec.getProfile()) + //specific + .replicas(taskSpec.getReplicas()) + .servicePorts(List.of(servicePort, grpcPort)) + .serviceType(taskSpec.getServiceType() != null ? taskSpec.getServiceType() : CoreServiceType.NodePort) + .build(); + + k8sServeRunnable.setId(run.getId()); + k8sServeRunnable.setProject(run.getProject()); + + return k8sServeRunnable; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRuntime.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRuntime.java new file mode 100644 index 00000000..01493362 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/HuggingfaceServeRuntime.java @@ -0,0 +1,110 @@ +package it.smartcommunitylabdhub.runtime.huggingface; + +import it.smartcommunitylabdhub.commons.accessors.spec.RunSpecAccessor; +import it.smartcommunitylabdhub.commons.annotations.infrastructure.RuntimeComponent; +import it.smartcommunitylabdhub.commons.models.base.Executable; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.entities.task.Task; +import it.smartcommunitylabdhub.commons.models.entities.task.TaskBaseSpec; +import it.smartcommunitylabdhub.commons.models.utils.RunUtils; +import it.smartcommunitylabdhub.commons.services.entities.SecretService; +import it.smartcommunitylabdhub.framework.k8s.base.K8sBaseRuntime; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeRunSpec; +import it.smartcommunitylabdhub.runtime.huggingface.specs.HuggingfaceServeTaskSpec; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeRunStatus; +import jakarta.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; + +@Slf4j +@RuntimeComponent(runtime = HuggingfaceServeRuntime.RUNTIME) +public class HuggingfaceServeRuntime + extends K8sBaseRuntime { + + public static final String RUNTIME = "huggingfaceserve"; + + @Autowired + private SecretService secretService; + + @Value("${runtime.huggingfaceserve.image}") + private String image; + + public HuggingfaceServeRuntime() { + super(HuggingfaceServeRunSpec.KIND); + } + + @Override + public HuggingfaceServeRunSpec build(@NotNull Executable function, @NotNull Task task, @NotNull Run run) { + //check run kind + if (!HuggingfaceServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + HuggingfaceServeRunSpec.KIND + ) + ); + } + + HuggingfaceServeFunctionSpec funSpec = HuggingfaceServeFunctionSpec.with(function.getSpec()); + HuggingfaceServeRunSpec runSpec = HuggingfaceServeRunSpec.with(run.getSpec()); + + String kind = task.getKind(); + + //build task spec as defined + TaskBaseSpec taskSpec = + switch (kind) { + case HuggingfaceServeTaskSpec.KIND -> { + yield HuggingfaceServeTaskSpec.with(task.getSpec()); + } + default -> throw new IllegalArgumentException( + "Kind not recognized. Cannot retrieve the right builder or specialize Spec for Run and Task." + ); + }; + + //build run merging task spec overrides + Map map = new HashMap<>(); + map.putAll(runSpec.toMap()); + taskSpec.toMap().forEach(map::putIfAbsent); + + HuggingfaceServeRunSpec serveSpec = HuggingfaceServeRunSpec.with(map); + //ensure function is not modified + serveSpec.setFunctionSpec(funSpec); + + return serveSpec; + } + + @Override + public K8sRunnable run(@NotNull Run run) { + //check run kind + if (!HuggingfaceServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + HuggingfaceServeRunSpec.KIND + ) + ); + } + + HuggingfaceServeRunSpec runSpec = HuggingfaceServeRunSpec.with(run.getSpec()); + + // Create string run accessor from task + RunSpecAccessor runAccessor = RunUtils.parseTask(runSpec.getTask()); + + return switch (runAccessor.getTask()) { + case HuggingfaceServeTaskSpec.KIND -> new HuggingfaceServeRunner( + image, + runSpec.getFunctionSpec(), + secretService.groupSecrets(run.getProject(), runSpec.getTaskServeSpec().getSecrets()), + k8sBuilderHelper + ) + .produce(run); + default -> throw new IllegalArgumentException("Kind not recognized. Cannot retrieve the right Runner"); + }; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceBackend.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceBackend.java new file mode 100644 index 00000000..0d42af01 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceBackend.java @@ -0,0 +1,17 @@ +package it.smartcommunitylabdhub.runtime.huggingface.models; + +public enum HuggingfaceBackend { + AUTO("auto"), + VLLM("vllm"), + HUGGINGFACE("huggingface"); + + private final String backend; + + HuggingfaceBackend(String backend) { + this.backend = backend; + } + + public String getBackend() { + return backend; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceDType.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceDType.java new file mode 100644 index 00000000..c478687a --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceDType.java @@ -0,0 +1,20 @@ +package it.smartcommunitylabdhub.runtime.huggingface.models; + +public enum HuggingfaceDType { + AUTO("auto"), + FLOAT16("float16"), + FLOAT32("float32"), + BFLOAT16("bfloat16"), + FLOAT("float"), + HALF("half"); + + private final String dtype; + + HuggingfaceDType(String dtype) { + this.dtype = dtype; + } + + public String getDType() { + return dtype; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceTask.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceTask.java new file mode 100644 index 00000000..b0b25a83 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/models/HuggingfaceTask.java @@ -0,0 +1,19 @@ +package it.smartcommunitylabdhub.runtime.huggingface.models; + +public enum HuggingfaceTask { + sequence_classification("sequence-classification"), + token_classification("token-classification"), + fill_mask("fill-mask"), + text_generation("text-generation"), + text2text_generation("text2text-generation"); + + private final String task; + + HuggingfaceTask(String task) { + this.task = task; + } + + public String getTask() { + return task; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeFunctionSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeFunctionSpec.java new file mode 100644 index 00000000..86dc779b --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeFunctionSpec.java @@ -0,0 +1,28 @@ +package it.smartcommunitylabdhub.runtime.huggingface.specs; + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.huggingface.HuggingfaceServeRuntime; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeFunctionSpec; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType( + runtime = HuggingfaceServeRuntime.RUNTIME, + kind = HuggingfaceServeRuntime.RUNTIME, + entity = EntityName.FUNCTION +) +public class HuggingfaceServeFunctionSpec extends ModelServeFunctionSpec { + + public static HuggingfaceServeFunctionSpec with(Map data) { + HuggingfaceServeFunctionSpec spec = new HuggingfaceServeFunctionSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeRunSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeRunSpec.java new file mode 100644 index 00000000..39db58e6 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeRunSpec.java @@ -0,0 +1,51 @@ +package it.smartcommunitylabdhub.runtime.huggingface.specs; + +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.entities.run.RunBaseSpec; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.huggingface.HuggingfaceServeRuntime; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = HuggingfaceServeRuntime.RUNTIME, kind = HuggingfaceServeRunSpec.KIND, entity = EntityName.RUN) +public class HuggingfaceServeRunSpec extends RunBaseSpec { + + @JsonUnwrapped + private HuggingfaceServeFunctionSpec functionSpec; + + @JsonUnwrapped + private HuggingfaceServeTaskSpec taskServeSpec; + + public static final String KIND = HuggingfaceServeRuntime.RUNTIME + "+run"; + + @Override + public void configure(Map data) { + super.configure(data); + + HuggingfaceServeRunSpec spec = mapper.convertValue(data, HuggingfaceServeRunSpec.class); + + this.functionSpec = spec.getFunctionSpec(); + this.taskServeSpec = spec.getTaskServeSpec(); + } + + public void setFunctionSpec(HuggingfaceServeFunctionSpec functionSpec) { + this.functionSpec = functionSpec; + } + + public void setTaskServeSpec(HuggingfaceServeTaskSpec taskServeSpec) { + this.taskServeSpec = taskServeSpec; + } + + public static HuggingfaceServeRunSpec with(Map data) { + HuggingfaceServeRunSpec spec = new HuggingfaceServeRunSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeTaskSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeTaskSpec.java new file mode 100644 index 00000000..ba430085 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/huggingface/specs/HuggingfaceServeTaskSpec.java @@ -0,0 +1,141 @@ +package it.smartcommunitylabdhub.runtime.huggingface.specs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.huggingface.HuggingfaceServeRuntime; +import it.smartcommunitylabdhub.runtime.huggingface.models.HuggingfaceBackend; +import it.smartcommunitylabdhub.runtime.huggingface.models.HuggingfaceDType; +import it.smartcommunitylabdhub.runtime.huggingface.models.HuggingfaceTask; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeServeTaskSpec; +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = HuggingfaceServeRuntime.RUNTIME, kind = HuggingfaceServeTaskSpec.KIND, entity = EntityName.TASK) +public class HuggingfaceServeTaskSpec extends ModelServeServeTaskSpec { + + public static final String KIND = HuggingfaceServeRuntime.RUNTIME + "+serve"; + + // The ML task name + @JsonProperty("huggingface_task") + @Schema(title = "fields.huggingface.task.title", description = "fields.huggingface.task.description") + private HuggingfaceTask huggingfaceTask; + + // the backend to use to load the model. + @JsonProperty("backend") + @Schema(title = "fields.huggingface.backend.title", description = "fields.huggingface.backend.description") + private HuggingfaceBackend backend; + + // Huggingface tokenizer revision + @JsonProperty("tokenizer_revision") + @Schema( + title = "fields.huggingface.tokenizerrevision.title", + description = "fields.huggingface.tokenizerrevision.description" + ) + private String tokenizerRevision; + + // Huggingface max sequence length for the tokenizer + @JsonProperty("max_length") + @Schema(title = "fields.huggingface.maxlength.title", description = "fields.huggingface.maxlength.description") + private Integer maxLength; + + // do not use lower case for the tokenizer + @JsonProperty("disable_lower_case") + @Schema( + title = "fields.huggingface.disablelowercase.title", + description = "fields.huggingface.disablelowercase.description" + ) + private Boolean disableLowerCase; + + // the sequences will not be encoded with the special tokens relative to their model + @JsonProperty("disable_special_tokens") + @Schema( + title = "fields.huggingface.disablespecialtokens.title", + description = "fields.huggingface.disablespecialtokens.description" + ) + private Boolean disableSpecialTokens; + + // data type to load the weights in. + @JsonProperty("dtype") + @Schema(title = "fields.huggingface.dtype.title", description = "fields.huggingface.dtype.description") + private HuggingfaceDType dtype; + + // allow loading of models and tokenizers with custom code + @JsonProperty("trust_remote_code") + @Schema( + title = "fields.huggingface.trustremotecode.title", + description = "fields.huggingface.trustremotecode.description" + ) + private Boolean trustRemoteCode; + + // the tensor input names passed to the model + @JsonProperty("tensor_input_names") + @Schema( + title = "fields.huggingface.tensorinputnames.title", + description = "fields.huggingface.tensorinputnames.description" + ) + private List tensorInputNames; + + // Return token type ids + @JsonProperty("return_token_type_ids") + @Schema( + title = "fields.huggingface.returntokentypeids.title", + description = "fields.huggingface.returntokentypeids.description" + ) + private Boolean returnTokenTypeIds; + + // Return all probabilities + @JsonProperty("return_probabilities") + @Schema( + title = "fields.huggingface.returnprobabilities.title", + description = "fields.huggingface.returnprobabilities.description" + ) + private Boolean returnProbabilities; + + // Disable logging requests + @JsonProperty("disable_log_requests") + @Schema( + title = "fields.huggingface.disablelogrequests.title", + description = "fields.huggingface.disablelogrequests.description" + ) + private Boolean disableLogRequests; + + // Max number of prompt characters or prompt + @JsonProperty("max_log_len") + @Schema(title = "fields.huggingface.maxloglen.title", description = "fields.huggingface.maxloglen.description") + private Integer maxLogLen; + + @Override + public void configure(Map data) { + super.configure(data); + HuggingfaceServeTaskSpec spec = mapper.convertValue(data, HuggingfaceServeTaskSpec.class); + this.huggingfaceTask = spec.getHuggingfaceTask(); + + this.backend = spec.getBackend(); + this.tokenizerRevision = spec.getTokenizerRevision(); + this.maxLength = spec.getMaxLength(); + this.disableLowerCase = spec.getDisableLowerCase(); + this.disableSpecialTokens = spec.getDisableSpecialTokens(); + this.dtype = spec.getDtype(); + this.trustRemoteCode = spec.getTrustRemoteCode(); + this.tensorInputNames = spec.getTensorInputNames(); + this.returnTokenTypeIds = spec.getReturnTokenTypeIds(); + this.returnProbabilities = spec.getReturnProbabilities(); + this.disableLogRequests = spec.getDisableLogRequests(); + this.maxLogLen = spec.getMaxLogLen(); + } + + public static HuggingfaceServeTaskSpec with(Map data) { + HuggingfaceServeTaskSpec spec = new HuggingfaceServeTaskSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRunner.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRunner.java new file mode 100644 index 00000000..86ba2e53 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRunner.java @@ -0,0 +1,154 @@ +package it.smartcommunitylabdhub.runtime.mlflow; + +import it.smartcommunitylabdhub.commons.accessors.spec.TaskSpecAccessor; +import it.smartcommunitylabdhub.commons.exceptions.CoreRuntimeException; +import it.smartcommunitylabdhub.commons.infrastructure.Runner; +import it.smartcommunitylabdhub.commons.jackson.JacksonMapper; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.enums.State; +import it.smartcommunitylabdhub.commons.models.utils.TaskUtils; +import it.smartcommunitylabdhub.framework.k8s.kubernetes.K8sBuilderHelper; +import it.smartcommunitylabdhub.framework.k8s.model.ContextRef; +import it.smartcommunitylabdhub.framework.k8s.model.ContextSource; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreEnv; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreLabel; +import it.smartcommunitylabdhub.framework.k8s.objects.CorePort; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreServiceType; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sServeRunnable; +import it.smartcommunitylabdhub.runtime.mlflow.models.MLFlowSettingsParameters; +import it.smartcommunitylabdhub.runtime.mlflow.models.MLFlowSettingsSpec; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeRunSpec; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeTaskSpec; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeServeTaskSpec; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +public class MlflowServeRunner implements Runner { + + private static final int HTTP_PORT = 8080; + private static final int GRPC_PORT = 8081; + + private final String image; + private final MlflowServeFunctionSpec functionSpec; + private final Map> groupedSecrets; + + private final K8sBuilderHelper k8sBuilderHelper; + + public MlflowServeRunner( + String image, + MlflowServeFunctionSpec functionSpec, + Map> groupedSecrets, + K8sBuilderHelper k8sBuilderHelper + ) { + this.image = image; + this.functionSpec = functionSpec; + this.groupedSecrets = groupedSecrets; + this.k8sBuilderHelper = k8sBuilderHelper; + } + + @Override + public K8sRunnable produce(Run run) { + MlflowServeRunSpec runSpec = MlflowServeRunSpec.with(run.getSpec()); + ModelServeServeTaskSpec taskSpec = runSpec.getTaskServeSpec(); + TaskSpecAccessor taskAccessor = TaskUtils.parseFunction(taskSpec.getFunction()); + + List coreEnvList = new ArrayList<>( + List.of(new CoreEnv("PROJECT_NAME", run.getProject()), new CoreEnv("RUN_ID", run.getId())) + ); + + Optional.ofNullable(taskSpec.getEnvs()).ifPresent(coreEnvList::addAll); + + UriComponents uri = UriComponentsBuilder.fromUriString(functionSpec.getPath()).build(); + String source = functionSpec.getPath().trim(); + if (!source.endsWith("/")) source += "/"; + + //read source and build context + List contextRefs = Collections.singletonList( + ContextRef.builder().source(source).protocol(uri.getScheme()).destination("model").build() + ); + List contextSources = new ArrayList<>(); + + MLFlowSettingsSpec mlFlowSettingsSpec = MLFlowSettingsSpec + .builder() + .name(StringUtils.hasText(functionSpec.getModelName()) ? functionSpec.getModelName() : "model") + .implementation("mlserver_mlflow.MLflowRuntime") + // .platform() + .parameters( + MLFlowSettingsParameters + .builder() + .uri("./model") + // .contentType() + .build() + ) + .build(); + + //write model settings + try { + String setttingsString = JacksonMapper.CUSTOM_OBJECT_MAPPER.writeValueAsString(mlFlowSettingsSpec); + ContextSource entry = ContextSource + .builder() + .name("model-settings.json") + .base64(Base64.getEncoder().encodeToString(setttingsString.getBytes())) + .build(); + contextSources.add(entry); + } catch (IOException ioe) { + throw new CoreRuntimeException("error with reading entrypoint for runtime-mlflow"); + } + + List args = new ArrayList<>(List.of("start", "/shared")); + + CorePort servicePort = new CorePort(HTTP_PORT, HTTP_PORT); + CorePort grpcPort = new CorePort(GRPC_PORT, GRPC_PORT); + + String img = StringUtils.hasText(functionSpec.getImage()) ? functionSpec.getImage() : image; + + //build runnable + K8sRunnable k8sServeRunnable = K8sServeRunnable + .builder() + .runtime(MlflowServeRuntime.RUNTIME) + .task(MlflowServeTaskSpec.KIND) + .state(State.READY.name()) + .labels( + k8sBuilderHelper != null + ? List.of(new CoreLabel(k8sBuilderHelper.getLabelName("function"), taskAccessor.getFunction())) + : null + ) + //base + .image(img) + .command("mlserver") + .args(args.toArray(new String[0])) + .contextSources(contextSources) + .contextRefs(contextRefs) + .envs(coreEnvList) + .secrets(groupedSecrets) + .resources(taskSpec.getResources()) + .volumes(taskSpec.getVolumes()) + .nodeSelector(taskSpec.getNodeSelector()) + .affinity(taskSpec.getAffinity()) + .tolerations(taskSpec.getTolerations()) + .runtimeClass(taskSpec.getRuntimeClass()) + .priorityClass(taskSpec.getPriorityClass()) + .template(taskSpec.getProfile()) + //specific + .replicas(taskSpec.getReplicas()) + .servicePorts(List.of(servicePort, grpcPort)) + .serviceType(taskSpec.getServiceType() != null ? taskSpec.getServiceType() : CoreServiceType.NodePort) + .build(); + + k8sServeRunnable.setId(run.getId()); + k8sServeRunnable.setProject(run.getProject()); + + return k8sServeRunnable; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRuntime.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRuntime.java new file mode 100644 index 00000000..134b09fc --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/MlflowServeRuntime.java @@ -0,0 +1,110 @@ +package it.smartcommunitylabdhub.runtime.mlflow; + +import it.smartcommunitylabdhub.commons.accessors.spec.RunSpecAccessor; +import it.smartcommunitylabdhub.commons.annotations.infrastructure.RuntimeComponent; +import it.smartcommunitylabdhub.commons.models.base.Executable; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.entities.task.Task; +import it.smartcommunitylabdhub.commons.models.entities.task.TaskBaseSpec; +import it.smartcommunitylabdhub.commons.models.utils.RunUtils; +import it.smartcommunitylabdhub.commons.services.entities.SecretService; +import it.smartcommunitylabdhub.framework.k8s.base.K8sBaseRuntime; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeRunSpec; +import it.smartcommunitylabdhub.runtime.mlflow.specs.MlflowServeTaskSpec; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeRunStatus; +import jakarta.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; + +@Slf4j +@RuntimeComponent(runtime = MlflowServeRuntime.RUNTIME) +public class MlflowServeRuntime + extends K8sBaseRuntime { + + public static final String RUNTIME = "mlflowserve"; + + @Autowired + private SecretService secretService; + + @Value("${runtime.mlflowserve.image}") + private String image; + + public MlflowServeRuntime() { + super(MlflowServeRunSpec.KIND); + } + + @Override + public MlflowServeRunSpec build(@NotNull Executable function, @NotNull Task task, @NotNull Run run) { + //check run kind + if (!MlflowServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + MlflowServeRunSpec.KIND + ) + ); + } + + MlflowServeFunctionSpec funSpec = MlflowServeFunctionSpec.with(function.getSpec()); + MlflowServeRunSpec runSpec = MlflowServeRunSpec.with(run.getSpec()); + + String kind = task.getKind(); + + //build task spec as defined + TaskBaseSpec taskSpec = + switch (kind) { + case MlflowServeTaskSpec.KIND -> { + yield MlflowServeTaskSpec.with(task.getSpec()); + } + default -> throw new IllegalArgumentException( + "Kind not recognized. Cannot retrieve the right builder or specialize Spec for Run and Task." + ); + }; + + //build run merging task spec overrides + Map map = new HashMap<>(); + map.putAll(runSpec.toMap()); + taskSpec.toMap().forEach(map::putIfAbsent); + + MlflowServeRunSpec serveSpec = MlflowServeRunSpec.with(map); + //ensure function is not modified + serveSpec.setFunctionSpec(funSpec); + + return serveSpec; + } + + @Override + public K8sRunnable run(@NotNull Run run) { + //check run kind + if (!MlflowServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + MlflowServeRunSpec.KIND + ) + ); + } + + MlflowServeRunSpec runSpec = MlflowServeRunSpec.with(run.getSpec()); + + // Create string run accessor from task + RunSpecAccessor runAccessor = RunUtils.parseTask(runSpec.getTask()); + + return switch (runAccessor.getTask()) { + case MlflowServeTaskSpec.KIND -> new MlflowServeRunner( + image, + runSpec.getFunctionSpec(), + secretService.groupSecrets(run.getProject(), runSpec.getTaskServeSpec().getSecrets()), + k8sBuilderHelper + ) + .produce(run); + default -> throw new IllegalArgumentException("Kind not recognized. Cannot retrieve the right Runner"); + }; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsParameters.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsParameters.java new file mode 100644 index 00000000..7434c96d --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsParameters.java @@ -0,0 +1,29 @@ +package it.smartcommunitylabdhub.runtime.mlflow.models; + +import com.fasterxml.jackson.annotation.JsonProperty; +import java.io.Serializable; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.ToString; + +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@Builder +@ToString +public class MLFlowSettingsParameters { + + private String uri; + + @JsonProperty("content_type") + private String contentType; + + private String version; + private String format; + private Map extra; +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsSpec.java new file mode 100644 index 00000000..df4c64ef --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/models/MLFlowSettingsSpec.java @@ -0,0 +1,22 @@ +package it.smartcommunitylabdhub.runtime.mlflow.models; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.ToString; + +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +@Builder +@ToString +public class MLFlowSettingsSpec { + + private String name; + private String implementation; + private MLFlowSettingsParameters parameters; + private String platform; +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeFunctionSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeFunctionSpec.java new file mode 100644 index 00000000..727c56b8 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeFunctionSpec.java @@ -0,0 +1,24 @@ +package it.smartcommunitylabdhub.runtime.mlflow.specs; + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.mlflow.MlflowServeRuntime; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeFunctionSpec; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = MlflowServeRuntime.RUNTIME, kind = MlflowServeRuntime.RUNTIME, entity = EntityName.FUNCTION) +public class MlflowServeFunctionSpec extends ModelServeFunctionSpec { + + public static MlflowServeFunctionSpec with(Map data) { + MlflowServeFunctionSpec spec = new MlflowServeFunctionSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeRunSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeRunSpec.java new file mode 100644 index 00000000..44647ed9 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeRunSpec.java @@ -0,0 +1,49 @@ +package it.smartcommunitylabdhub.runtime.mlflow.specs; + +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.entities.run.RunBaseSpec; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.mlflow.MlflowServeRuntime; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = MlflowServeRuntime.RUNTIME, kind = MlflowServeRunSpec.KIND, entity = EntityName.RUN) +public class MlflowServeRunSpec extends RunBaseSpec { + + public static final String KIND = MlflowServeRuntime.RUNTIME + "+run"; + + @JsonUnwrapped + private MlflowServeFunctionSpec functionSpec; + + @JsonUnwrapped + private MlflowServeTaskSpec taskServeSpec; + + @Override + public void configure(Map data) { + super.configure(data); + MlflowServeRunSpec spec = mapper.convertValue(data, MlflowServeRunSpec.class); + this.functionSpec = spec.getFunctionSpec(); + this.taskServeSpec = spec.getTaskServeSpec(); + } + + public void setFunctionSpec(MlflowServeFunctionSpec functionSpec) { + this.functionSpec = functionSpec; + } + + public void setTaskServeSpec(MlflowServeTaskSpec taskServeSpec) { + this.taskServeSpec = taskServeSpec; + } + + public static MlflowServeRunSpec with(Map data) { + MlflowServeRunSpec spec = new MlflowServeRunSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeTaskSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeTaskSpec.java new file mode 100644 index 00000000..84017628 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/mlflow/specs/MlflowServeTaskSpec.java @@ -0,0 +1,26 @@ +package it.smartcommunitylabdhub.runtime.mlflow.specs; + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.mlflow.MlflowServeRuntime; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeServeTaskSpec; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = MlflowServeRuntime.RUNTIME, kind = MlflowServeTaskSpec.KIND, entity = EntityName.TASK) +public class MlflowServeTaskSpec extends ModelServeServeTaskSpec { + + public static final String KIND = MlflowServeRuntime.RUNTIME + "+serve"; + + public static MlflowServeTaskSpec with(Map data) { + MlflowServeTaskSpec spec = new MlflowServeTaskSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeFunctionSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeFunctionSpec.java new file mode 100644 index 00000000..4b3c0cab --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeFunctionSpec.java @@ -0,0 +1,43 @@ +package it.smartcommunitylabdhub.runtime.modelserve.specs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; +import it.smartcommunitylabdhub.commons.models.entities.function.FunctionBaseSpec; +import jakarta.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +public class ModelServeFunctionSpec extends FunctionBaseSpec { + + @JsonProperty("path") + @NotNull + @Schema(title = "fields.modelserve.path.title", description = "fields.modelserve.path.description") + private String path; + + @JsonProperty("model_name") + @Schema(title = "fields.modelserve.modelname.title", description = "fields.modelserve.modelname.description") + private String modelName; + + @JsonProperty("image") + @Schema(title = "fields.container.image.title", description = "fields.container.image.description") + private String image; + + @Override + public void configure(Map data) { + super.configure(data); + + ModelServeFunctionSpec spec = mapper.convertValue(data, ModelServeFunctionSpec.class); + + this.modelName = spec.getModelName(); + this.path = spec.getPath(); + this.image = spec.getImage(); + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeRunStatus.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeRunStatus.java new file mode 100644 index 00000000..0ea80120 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeRunStatus.java @@ -0,0 +1,15 @@ +package it.smartcommunitylabdhub.runtime.modelserve.specs; + +import com.fasterxml.jackson.annotation.JsonInclude; +import it.smartcommunitylabdhub.commons.models.entities.run.RunBaseStatus; +import lombok.Builder; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@Builder +@JsonInclude(JsonInclude.Include.NON_NULL) +public class ModelServeRunStatus extends RunBaseStatus {} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeServeTaskSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeServeTaskSpec.java new file mode 100644 index 00000000..b5d51a39 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/modelserve/specs/ModelServeServeTaskSpec.java @@ -0,0 +1,40 @@ +package it.smartcommunitylabdhub.runtime.modelserve.specs; + +import com.fasterxml.jackson.annotation.JsonProperty; +import io.swagger.v3.oas.annotations.media.Schema; +import it.smartcommunitylabdhub.framework.k8s.base.K8sTaskBaseSpec; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreServiceType; +import jakarta.validation.constraints.Min; +import java.io.Serializable; +import java.util.Map; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@AllArgsConstructor +public class ModelServeServeTaskSpec extends K8sTaskBaseSpec { + + @JsonProperty("replicas") + @Min(0) + private Integer replicas; + + // ClusterIP or NodePort + @JsonProperty(value = "service_type", defaultValue = "NodePort") + @Schema(defaultValue = "NodePort") + private CoreServiceType serviceType; + + @Override + public void configure(Map data) { + super.configure(data); + + ModelServeServeTaskSpec spec = mapper.convertValue(data, ModelServeServeTaskSpec.class); + + this.replicas = spec.getReplicas(); + + this.setServiceType(spec.getServiceType()); + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRunner.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRunner.java new file mode 100644 index 00000000..9f96cafd --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRunner.java @@ -0,0 +1,130 @@ +package it.smartcommunitylabdhub.runtime.sklearn; + +import it.smartcommunitylabdhub.commons.accessors.spec.TaskSpecAccessor; +import it.smartcommunitylabdhub.commons.infrastructure.Runner; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.enums.State; +import it.smartcommunitylabdhub.commons.models.utils.TaskUtils; +import it.smartcommunitylabdhub.framework.k8s.kubernetes.K8sBuilderHelper; +import it.smartcommunitylabdhub.framework.k8s.model.ContextRef; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreEnv; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreLabel; +import it.smartcommunitylabdhub.framework.k8s.objects.CorePort; +import it.smartcommunitylabdhub.framework.k8s.objects.CoreServiceType; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sServeRunnable; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeServeTaskSpec; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeRunSpec; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeTaskSpec; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.springframework.util.StringUtils; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; + +public class SklearnServeRunner implements Runner { + + private static final int HTTP_PORT = 8080; + private static final int GRPC_PORT = 8081; + + private final String image; + private final ModelServeFunctionSpec functionSpec; + private final Map> groupedSecrets; + + private final K8sBuilderHelper k8sBuilderHelper; + + public SklearnServeRunner( + String image, + SklearnServeFunctionSpec functionSpec, + Map> groupedSecrets, + K8sBuilderHelper k8sBuilderHelper + ) { + this.image = image; + this.functionSpec = functionSpec; + this.groupedSecrets = groupedSecrets; + this.k8sBuilderHelper = k8sBuilderHelper; + } + + @Override + public K8sRunnable produce(Run run) { + SklearnServeRunSpec runSpec = SklearnServeRunSpec.with(run.getSpec()); + ModelServeServeTaskSpec taskSpec = runSpec.getTaskServeSpec(); + TaskSpecAccessor taskAccessor = TaskUtils.parseFunction(taskSpec.getFunction()); + + List coreEnvList = new ArrayList<>( + List.of(new CoreEnv("PROJECT_NAME", run.getProject()), new CoreEnv("RUN_ID", run.getId())) + ); + + Optional.ofNullable(taskSpec.getEnvs()).ifPresent(coreEnvList::addAll); + + UriComponents uri = UriComponentsBuilder.fromUriString(functionSpec.getPath()).build(); + + //read source and build context + List contextRefs = Collections.singletonList( + ContextRef.builder().source(functionSpec.getPath()).protocol(uri.getScheme()).destination("model").build() + ); + + List args = new ArrayList<>( + List.of( + "-m", + "sklearnserver", + "--model_dir", + "/shared/model", + "--model_name", + StringUtils.hasText(functionSpec.getModelName()) ? functionSpec.getModelName() : "model", + "--protocol", + "v2", + "--enable_docs_url", + "true" + ) + ); + + CorePort servicePort = new CorePort(HTTP_PORT, HTTP_PORT); + CorePort grpcPort = new CorePort(GRPC_PORT, GRPC_PORT); + + String img = StringUtils.hasText(functionSpec.getImage()) ? functionSpec.getImage() : image; + + //build runnable + K8sRunnable k8sServeRunnable = K8sServeRunnable + .builder() + .runtime(SklearnServeRuntime.RUNTIME) + .task(SklearnServeTaskSpec.KIND) + .state(State.READY.name()) + .labels( + k8sBuilderHelper != null + ? List.of(new CoreLabel(k8sBuilderHelper.getLabelName("function"), taskAccessor.getFunction())) + : null + ) + //base + .image(img) + .command("python") + .args(args.toArray(new String[0])) + .contextRefs(contextRefs) + .envs(coreEnvList) + .secrets(groupedSecrets) + .resources(taskSpec.getResources()) + .volumes(taskSpec.getVolumes()) + .nodeSelector(taskSpec.getNodeSelector()) + .affinity(taskSpec.getAffinity()) + .tolerations(taskSpec.getTolerations()) + .runtimeClass(taskSpec.getRuntimeClass()) + .priorityClass(taskSpec.getPriorityClass()) + .template(taskSpec.getProfile()) + //specific + .replicas(taskSpec.getReplicas()) + .servicePorts(List.of(servicePort, grpcPort)) + .serviceType(taskSpec.getServiceType() != null ? taskSpec.getServiceType() : CoreServiceType.NodePort) + .build(); + + k8sServeRunnable.setId(run.getId()); + k8sServeRunnable.setProject(run.getProject()); + + return k8sServeRunnable; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRuntime.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRuntime.java new file mode 100644 index 00000000..cece785f --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/SklearnServeRuntime.java @@ -0,0 +1,110 @@ +package it.smartcommunitylabdhub.runtime.sklearn; + +import it.smartcommunitylabdhub.commons.accessors.spec.RunSpecAccessor; +import it.smartcommunitylabdhub.commons.annotations.infrastructure.RuntimeComponent; +import it.smartcommunitylabdhub.commons.models.base.Executable; +import it.smartcommunitylabdhub.commons.models.entities.run.Run; +import it.smartcommunitylabdhub.commons.models.entities.task.Task; +import it.smartcommunitylabdhub.commons.models.entities.task.TaskBaseSpec; +import it.smartcommunitylabdhub.commons.models.utils.RunUtils; +import it.smartcommunitylabdhub.commons.services.entities.SecretService; +import it.smartcommunitylabdhub.framework.k8s.base.K8sBaseRuntime; +import it.smartcommunitylabdhub.framework.k8s.runnables.K8sRunnable; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeRunStatus; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeRunSpec; +import it.smartcommunitylabdhub.runtime.sklearn.specs.SklearnServeTaskSpec; +import jakarta.validation.constraints.NotNull; +import java.io.Serializable; +import java.util.HashMap; +import java.util.Map; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; + +@Slf4j +@RuntimeComponent(runtime = SklearnServeRuntime.RUNTIME) +public class SklearnServeRuntime + extends K8sBaseRuntime { + + public static final String RUNTIME = "sklearnserve"; + + @Autowired + private SecretService secretService; + + @Value("${runtime.sklearnserve.image}") + private String image; + + public SklearnServeRuntime() { + super(SklearnServeRunSpec.KIND); + } + + @Override + public SklearnServeRunSpec build(@NotNull Executable function, @NotNull Task task, @NotNull Run run) { + //check run kind + if (!SklearnServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + SklearnServeRunSpec.KIND + ) + ); + } + + SklearnServeFunctionSpec funSpec = SklearnServeFunctionSpec.with(function.getSpec()); + SklearnServeRunSpec runSpec = SklearnServeRunSpec.with(run.getSpec()); + + String kind = task.getKind(); + + //build task spec as defined + TaskBaseSpec taskSpec = + switch (kind) { + case SklearnServeTaskSpec.KIND -> { + yield SklearnServeTaskSpec.with(task.getSpec()); + } + default -> throw new IllegalArgumentException( + "Kind not recognized. Cannot retrieve the right builder or specialize Spec for Run and Task." + ); + }; + + //build run merging task spec overrides + Map map = new HashMap<>(); + map.putAll(runSpec.toMap()); + taskSpec.toMap().forEach(map::putIfAbsent); + + SklearnServeRunSpec serveSpec = SklearnServeRunSpec.with(map); + //ensure function is not modified + serveSpec.setFunctionSpec(funSpec); + + return serveSpec; + } + + @Override + public K8sRunnable run(@NotNull Run run) { + //check run kind + if (!SklearnServeRunSpec.KIND.equals(run.getKind())) { + throw new IllegalArgumentException( + "Run kind {} unsupported, expecting {}".formatted( + String.valueOf(run.getKind()), + SklearnServeRunSpec.KIND + ) + ); + } + + SklearnServeRunSpec runSpec = SklearnServeRunSpec.with(run.getSpec()); + + // Create string run accessor from task + RunSpecAccessor runAccessor = RunUtils.parseTask(runSpec.getTask()); + + return switch (runAccessor.getTask()) { + case SklearnServeTaskSpec.KIND -> new SklearnServeRunner( + image, + runSpec.getFunctionSpec(), + secretService.groupSecrets(run.getProject(), runSpec.getTaskServeSpec().getSecrets()), + k8sBuilderHelper + ) + .produce(run); + default -> throw new IllegalArgumentException("Kind not recognized. Cannot retrieve the right Runner"); + }; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeFunctionSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeFunctionSpec.java new file mode 100644 index 00000000..7179e1b2 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeFunctionSpec.java @@ -0,0 +1,24 @@ +package it.smartcommunitylabdhub.runtime.sklearn.specs; + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeFunctionSpec; +import it.smartcommunitylabdhub.runtime.sklearn.SklearnServeRuntime; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = SklearnServeRuntime.RUNTIME, kind = SklearnServeRuntime.RUNTIME, entity = EntityName.FUNCTION) +public class SklearnServeFunctionSpec extends ModelServeFunctionSpec { + + public static SklearnServeFunctionSpec with(Map data) { + SklearnServeFunctionSpec spec = new SklearnServeFunctionSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeRunSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeRunSpec.java new file mode 100644 index 00000000..fb8de4f8 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeRunSpec.java @@ -0,0 +1,51 @@ +package it.smartcommunitylabdhub.runtime.sklearn.specs; + +import com.fasterxml.jackson.annotation.JsonUnwrapped; +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.entities.run.RunBaseSpec; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.sklearn.SklearnServeRuntime; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = SklearnServeRuntime.RUNTIME, kind = SklearnServeRunSpec.KIND, entity = EntityName.RUN) +public class SklearnServeRunSpec extends RunBaseSpec { + + @JsonUnwrapped + private SklearnServeFunctionSpec functionSpec; + + @JsonUnwrapped + private SklearnServeTaskSpec taskServeSpec; + + public static final String KIND = SklearnServeRuntime.RUNTIME + "+run"; + + @Override + public void configure(Map data) { + super.configure(data); + + SklearnServeRunSpec spec = mapper.convertValue(data, SklearnServeRunSpec.class); + + this.functionSpec = spec.getFunctionSpec(); + this.taskServeSpec = spec.getTaskServeSpec(); + } + + public void setFunctionSpec(SklearnServeFunctionSpec functionSpec) { + this.functionSpec = functionSpec; + } + + public void setTaskServeSpec(SklearnServeTaskSpec taskServeSpec) { + this.taskServeSpec = taskServeSpec; + } + + public static SklearnServeRunSpec with(Map data) { + SklearnServeRunSpec spec = new SklearnServeRunSpec(); + spec.configure(data); + return spec; + } +} diff --git a/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeTaskSpec.java b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeTaskSpec.java new file mode 100644 index 00000000..7c6298d4 --- /dev/null +++ b/modules/runtime-model-serve/src/main/java/it/smartcommunitylabdhub/runtime/sklearn/specs/SklearnServeTaskSpec.java @@ -0,0 +1,26 @@ +package it.smartcommunitylabdhub.runtime.sklearn.specs; + +import it.smartcommunitylabdhub.commons.annotations.common.SpecType; +import it.smartcommunitylabdhub.commons.models.enums.EntityName; +import it.smartcommunitylabdhub.runtime.modelserve.specs.ModelServeServeTaskSpec; +import it.smartcommunitylabdhub.runtime.sklearn.SklearnServeRuntime; +import java.io.Serializable; +import java.util.Map; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; + +@Getter +@Setter +@NoArgsConstructor +@SpecType(runtime = SklearnServeRuntime.RUNTIME, kind = SklearnServeTaskSpec.KIND, entity = EntityName.TASK) +public class SklearnServeTaskSpec extends ModelServeServeTaskSpec { + + public static final String KIND = SklearnServeRuntime.RUNTIME + "+serve"; + + public static SklearnServeTaskSpec with(Map data) { + SklearnServeTaskSpec spec = new SklearnServeTaskSpec(); + spec.configure(data); + return spec; + } +} diff --git a/pom.xml b/pom.xml index 618232d3..5484f85e 100644 --- a/pom.xml +++ b/pom.xml @@ -37,6 +37,7 @@ modules/runtime-mlrun modules/runtime-kfp modules/runtime-python + modules/runtime-model-serve modules/openmetadata-integration modules/files frontend