Skip to content

Commit

Permalink
Merge pull request #93 from scc-digitalhub/mlrun-build
Browse files Browse the repository at this point in the history
Mlrun-build
  • Loading branch information
kazhamiakin authored Apr 29, 2024
2 parents 8e4de68 + 53a049b commit 5c19746
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 11 deletions.
7 changes: 7 additions & 0 deletions application/src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ kaniko:
secret: ${KANIKO_SECRET:}
args: ${KANIKO_ARGS:}

# MLRun config
mlrun:
base-image: ${MLRUN_BASE_IMAGE:mlrun/mlrun}
image-prefix: ${MLRUN_IMAGE_PREFIX:dhcore}
image-registry: ${MLRUN_IMAGE_REGISTRY:}


# Kubernetes
kubernetes:
namespace: ${K8S_NAMESPACE:default}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package it.smartcommunitylabdhub.runtime.mlrun;

import it.smartcommunitylabdhub.commons.accessors.spec.RunSpecAccessor;
import it.smartcommunitylabdhub.commons.accessors.spec.TaskSpecAccessor;
import it.smartcommunitylabdhub.commons.annotations.infrastructure.RuntimeComponent;
import it.smartcommunitylabdhub.commons.exceptions.CoreRuntimeException;
import it.smartcommunitylabdhub.commons.exceptions.NoSuchEntityException;
Expand All @@ -12,28 +13,35 @@
import it.smartcommunitylabdhub.commons.models.entities.task.Task;
import it.smartcommunitylabdhub.commons.models.enums.State;
import it.smartcommunitylabdhub.commons.models.utils.RunUtils;
import it.smartcommunitylabdhub.commons.models.utils.TaskUtils;
import it.smartcommunitylabdhub.commons.services.RunnableStore;
import it.smartcommunitylabdhub.commons.services.entities.FunctionService;
import it.smartcommunitylabdhub.commons.services.entities.SecretService;
import it.smartcommunitylabdhub.framework.k8s.base.K8sTaskSpec;
import it.smartcommunitylabdhub.framework.k8s.runnables.K8sJobRunnable;
import it.smartcommunitylabdhub.runtime.mlrun.builders.MlrunMlrunBuilder;
import it.smartcommunitylabdhub.runtime.mlrun.runners.MlrunMlrunRunner;
import it.smartcommunitylabdhub.runtime.mlrun.builders.MlrunBuildBuilder;
import it.smartcommunitylabdhub.runtime.mlrun.builders.MlrunJobBuilder;
import it.smartcommunitylabdhub.runtime.mlrun.runners.MlrunBuildRunner;
import it.smartcommunitylabdhub.runtime.mlrun.runners.MlrunJobRunner;
import it.smartcommunitylabdhub.runtime.mlrun.specs.function.FunctionMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.run.RunMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunBuildSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunJobSpec;
import it.smartcommunitylabdhub.runtime.mlrun.status.RunMlrunStatus;
import jakarta.validation.constraints.NotNull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.util.StringUtils;

@RuntimeComponent(runtime = MlrunRuntime.RUNTIME)
@Slf4j
public class MlrunRuntime implements Runtime<FunctionMlrunSpec, RunMlrunSpec, RunMlrunStatus, K8sJobRunnable> {

public static final String RUNTIME = "mlrun";

private final MlrunMlrunBuilder builder = new MlrunMlrunBuilder();
private final MlrunJobBuilder jobBuilder = new MlrunJobBuilder();
private final MlrunBuildBuilder buildBuilder = new MlrunBuildBuilder();

@Autowired
SecretService secretService;
Expand All @@ -44,6 +52,15 @@ public class MlrunRuntime implements Runtime<FunctionMlrunSpec, RunMlrunSpec, Ru
@Value("${runtime.mlrun.image}")
private String image;

@Value("${mlrun.base-image}")
private String baseImage;

@Value("${mlrun.image-prefix}")
private String imagePrefix;

@Value("${mlrun.image-registry:}")
private String imageRegistry;

@Override
public RunMlrunSpec build(@NotNull Executable function, @NotNull Task task, @NotNull Run run) {
//check run kind
Expand All @@ -61,14 +78,23 @@ public RunMlrunSpec build(@NotNull Executable function, @NotNull Task task, @Not
return switch (kind) {
case TaskMlrunJobSpec.KIND -> {
TaskMlrunJobSpec taskMlrunSpec = new TaskMlrunJobSpec(task.getSpec());
yield builder.build(functionSpec, taskMlrunSpec, runSpec);
yield jobBuilder.build(functionSpec, taskMlrunSpec, runSpec);
}
case TaskMlrunBuildSpec.KIND -> {
TaskMlrunBuildSpec taskMlrunSpec = new TaskMlrunBuildSpec(task.getSpec());
taskMlrunSpec.setTargetImage(createTargetImage(function.getName(), function.getId()));
yield buildBuilder.build(functionSpec, taskMlrunSpec, runSpec);
}
default -> throw new IllegalArgumentException(
"Kind not recognized. Cannot retrieve the right builder or specialize Spec for Run and Task."
);
};
}

private String createTargetImage(String name, String id) {
return (StringUtils.hasText(imageRegistry) ? imageRegistry + "/" : "") + imagePrefix + "-" + name + ":" + id;
}

@Override
public K8sJobRunnable run(Run run) {
//check run kind
Expand All @@ -92,7 +118,20 @@ public K8sJobRunnable run(Run run) {
}
K8sTaskSpec k8s = taskSpec.getK8s() != null ? taskSpec.getK8s() : new K8sTaskSpec();

yield new MlrunMlrunRunner(image, secretService.groupSecrets(run.getProject(), k8s.getSecrets()))
yield new MlrunJobRunner(image, secretService.groupSecrets(run.getProject(), k8s.getSecrets()))
.produce(run);
}
case TaskMlrunBuildSpec.KIND -> {
TaskMlrunBuildSpec taskSpec = runSpec.getBuildSpec();
TaskSpecAccessor accessor = TaskUtils.parseFunction(taskSpec.getFunction());
taskSpec.setTargetImage(createTargetImage(accessor.getFunction(), accessor.getVersion()));

if (taskSpec == null) {
throw new CoreRuntimeException("null or empty task definition");
}
K8sTaskSpec k8s = taskSpec.getK8s() != null ? taskSpec.getK8s() : new K8sTaskSpec();

yield new MlrunBuildRunner(image, secretService.groupSecrets(run.getProject(), k8s.getSecrets()))
.produce(run);
}
default -> throw new IllegalArgumentException("Kind not recognized. Cannot retrieve the right Runner");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package it.smartcommunitylabdhub.runtime.mlrun.builders;

import it.smartcommunitylabdhub.commons.infrastructure.Builder;
import it.smartcommunitylabdhub.runtime.mlrun.specs.function.FunctionMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.run.RunMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunBuildSpec;
import java.util.Optional;

/**
* MlrunMlrunBuilder
* <p>
* You can use this as a simple class or as a registered bean. If you want to retrieve this as bean from BuilderFactory
* you have to register it using the following annotation:
*
* @BuilderComponent(runtime = "mlrun", task = "build")
*/

public class MlrunBuildBuilder implements Builder<FunctionMlrunSpec, TaskMlrunBuildSpec, RunMlrunSpec> {

@Override
public RunMlrunSpec build(FunctionMlrunSpec funSpec, TaskMlrunBuildSpec taskSpec, RunMlrunSpec runSpec) {
RunMlrunSpec runMlrunSpec = new RunMlrunSpec(runSpec.toMap());
runMlrunSpec.setBuildSpec(taskSpec);
runMlrunSpec.setFuncSpec(funSpec);

//let run override k8s specs
Optional.ofNullable(runSpec.getJobSpec()).ifPresent(k8sSpec -> runSpec.getJobSpec().configure(k8sSpec.toMap()));

// Return a run spec
return runMlrunSpec;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
* You can use this as a simple class or as a registered bean. If you want to retrieve this as bean from BuilderFactory
* you have to register it using the following annotation:
*
* @BuilderComponent(runtime = "mlrun", task = "mlrun")
* @BuilderComponent(runtime = "mlrun", task = "job")
*/

public class MlrunMlrunBuilder implements Builder<FunctionMlrunSpec, TaskMlrunJobSpec, RunMlrunSpec> {
public class MlrunJobBuilder implements Builder<FunctionMlrunSpec, TaskMlrunJobSpec, RunMlrunSpec> {

@Override
public RunMlrunSpec build(FunctionMlrunSpec funSpec, TaskMlrunJobSpec taskSpec, RunMlrunSpec runSpec) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package it.smartcommunitylabdhub.runtime.mlrun.runners;

import it.smartcommunitylabdhub.commons.accessors.fields.StatusFieldAccessor;
import it.smartcommunitylabdhub.commons.exceptions.CoreRuntimeException;
import it.smartcommunitylabdhub.commons.infrastructure.Runner;
import it.smartcommunitylabdhub.commons.models.entities.run.Run;
import it.smartcommunitylabdhub.commons.models.enums.State;
import it.smartcommunitylabdhub.framework.k8s.base.K8sTaskSpec;
import it.smartcommunitylabdhub.framework.k8s.objects.CoreEnv;
import it.smartcommunitylabdhub.framework.k8s.runnables.K8sJobRunnable;
import it.smartcommunitylabdhub.runtime.mlrun.MlrunRuntime;
import it.smartcommunitylabdhub.runtime.mlrun.specs.run.RunMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunBuildSpec;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
* MlrunMlrunRunner
* <p>
* You can use this as a simple class or as a registered bean. If you want to retrieve this as bean from RunnerFactory
* you have to register it using the following annotation:
*
* @RunnerComponent(runtime = "mlrun", task = "build")
*/
public class MlrunBuildRunner implements Runner<K8sJobRunnable> {

private static final String TASK = "build";
private final String image;
private final Map<String, Set<String>> groupedSecrets;

public MlrunBuildRunner(String image, Map<String, Set<String>> groupedSecrets) {
this.image = image;
this.groupedSecrets = groupedSecrets;
}

@Override
public K8sJobRunnable produce(Run run) {
// Retrieve information about RunMlrunSpec
RunMlrunSpec runSpec = new RunMlrunSpec(run.getSpec());
TaskMlrunBuildSpec taskSpec = runSpec.getBuildSpec();
if (taskSpec == null) {
throw new CoreRuntimeException("null or empty task definition");
}

StatusFieldAccessor statusFieldAccessor = StatusFieldAccessor.with(run.getStatus());
K8sTaskSpec k8s = taskSpec.getK8s() != null ? taskSpec.getK8s() : new K8sTaskSpec();

List<CoreEnv> coreEnvList = new ArrayList<>(
List.of(new CoreEnv("PROJECT_NAME", run.getProject()), new CoreEnv("RUN_ID", run.getId()))
);

Optional.ofNullable(k8s.getEnvs()).ifPresent(coreEnvList::addAll);

//TODO: Create runnable using information from Run completed spec.
K8sJobRunnable k8sJobRunnable = K8sJobRunnable
.builder()
.runtime(MlrunRuntime.RUNTIME)
.task(TASK)
.image(image)
.command("python")
.args(List.of("wrapper.py").toArray(String[]::new))
.resources(k8s.getResources())
.nodeSelector(k8s.getNodeSelector())
.volumes(k8s.getVolumes())
.secrets(groupedSecrets)
.envs(coreEnvList)
.labels(k8s.getLabels())
.affinity(k8s.getAffinity())
.tolerations(k8s.getTolerations())
.state(State.READY.name())
.build();

k8sJobRunnable.setId(run.getId());
k8sJobRunnable.setProject(run.getProject());

return k8sJobRunnable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
* You can use this as a simple class or as a registered bean. If you want to retrieve this as bean from RunnerFactory
* you have to register it using the following annotation:
*
* @RunnerComponent(runtime = "mlrun", task = "mlrun")
* @RunnerComponent(runtime = "mlrun", task = "job")
*/
public class MlrunMlrunRunner implements Runner<K8sJobRunnable> {
public class MlrunJobRunner implements Runner<K8sJobRunnable> {

private static final String TASK = "mlrun";
private static final String TASK = "job";
private final String image;
private final Map<String, Set<String>> groupedSecrets;

public MlrunMlrunRunner(String image, Map<String, Set<String>> groupedSecrets) {
public MlrunJobRunner(String image, Map<String, Set<String>> groupedSecrets) {
this.image = image;
this.groupedSecrets = groupedSecrets;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import it.smartcommunitylabdhub.commons.models.enums.EntityName;
import it.smartcommunitylabdhub.runtime.mlrun.MlrunRuntime;
import it.smartcommunitylabdhub.runtime.mlrun.specs.function.FunctionMlrunSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunBuildSpec;
import it.smartcommunitylabdhub.runtime.mlrun.specs.task.TaskMlrunJobSpec;
import java.io.Serializable;
import java.util.HashMap;
Expand All @@ -32,6 +33,8 @@ public class RunMlrunSpec extends RunBaseSpec {

@JsonProperty("job_spec")
private TaskMlrunJobSpec jobSpec;
@JsonProperty("build_spec")
private TaskMlrunBuildSpec buildSpec;

@JsonProperty("function_spec")
private FunctionMlrunSpec funcSpec;
Expand All @@ -50,6 +53,7 @@ public void configure(Map<String, Serializable> data) {
this.parameters = spec.getParameters();

this.jobSpec = spec.getJobSpec();
this.buildSpec = spec.getBuildSpec();
this.funcSpec = spec.getFuncSpec();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package it.smartcommunitylabdhub.runtime.mlrun.specs.task;

import it.smartcommunitylabdhub.commons.annotations.common.SpecType;
import it.smartcommunitylabdhub.commons.models.entities.task.TaskBaseSpec;
import it.smartcommunitylabdhub.commons.models.enums.EntityName;
import it.smartcommunitylabdhub.framework.k8s.base.K8sTaskSpec;
import it.smartcommunitylabdhub.runtime.mlrun.MlrunRuntime;
import java.io.Serializable;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.annotation.JsonProperty;

import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

@Getter
@Setter
@NoArgsConstructor
@SpecType(runtime = MlrunRuntime.RUNTIME, kind = TaskMlrunBuildSpec.KIND, entity = EntityName.TASK)
public class TaskMlrunBuildSpec extends TaskBaseSpec {

public static final String KIND = "mlrun+build";

private K8sTaskSpec k8s = new K8sTaskSpec();

private List<String> commands;
@JsonProperty("force_build")
private Boolean forceBuild;
@JsonProperty("target_image")
private String targetImage;

public TaskMlrunBuildSpec(Map<String, Serializable> data) {
configure(data);
}

@Override
public void configure(Map<String, Serializable> data) {
super.configure(data);

TaskMlrunBuildSpec spec = mapper.convertValue(data, TaskMlrunBuildSpec.class);
this.k8s = spec.getK8s();
this.commands = spec.getCommands();
this.forceBuild = spec.getForceBuild();
this.targetImage = spec.getTargetImage();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package it.smartcommunitylabdhub.runtime.mlrun.specs.task;

import it.smartcommunitylabdhub.commons.infrastructure.SpecFactory;
import java.io.Serializable;
import java.util.Map;
import org.springframework.stereotype.Component;

@Component
public class TaskMlrunBuildSpecFactory implements SpecFactory<TaskMlrunBuildSpec> {

@Override
public TaskMlrunBuildSpec create() {
return new TaskMlrunBuildSpec();
}

@Override
public TaskMlrunBuildSpec create(Map<String, Serializable> data) {
return new TaskMlrunBuildSpec(data);
}
}

0 comments on commit 5c19746

Please sign in to comment.