From a7ffcf4d58b03ed1e7ea951939067dad7d623b82 Mon Sep 17 00:00:00 2001 From: pavlinam Date: Mon, 27 Nov 2023 15:52:48 +0100 Subject: [PATCH 1/2] Add MapTask to Flytekit --- .../main/java/org/flyte/api/v1/MapTask.java | 25 ++++++ .../main/java/org/flyte/examples/MapTask.java | 76 +++++++++++++++++++ .../java/org/flyte/flytekit/SdkMapTask.java | 68 +++++++++++++++++ .../flyte/jflyte/utils/ProjectClosure.java | 6 +- 4 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 flytekit-api/src/main/java/org/flyte/api/v1/MapTask.java create mode 100644 flytekit-examples/src/main/java/org/flyte/examples/MapTask.java create mode 100644 flytekit-java/src/main/java/org/flyte/flytekit/SdkMapTask.java diff --git a/flytekit-api/src/main/java/org/flyte/api/v1/MapTask.java b/flytekit-api/src/main/java/org/flyte/api/v1/MapTask.java new file mode 100644 index 000000000..fc7fd3436 --- /dev/null +++ b/flytekit-api/src/main/java/org/flyte/api/v1/MapTask.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.api.v1; + +public interface MapTask extends RunnableTask { + + @Override + default String getType() { + return "container_array"; + } +} diff --git a/flytekit-examples/src/main/java/org/flyte/examples/MapTask.java b/flytekit-examples/src/main/java/org/flyte/examples/MapTask.java new file mode 100644 index 000000000..39c6ee437 --- /dev/null +++ b/flytekit-examples/src/main/java/org/flyte/examples/MapTask.java @@ -0,0 +1,76 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples; + +import com.google.auto.service.AutoService; +import com.google.auto.value.AutoValue; +import java.util.List; +import java.util.stream.Collectors; +import org.flyte.flytekit.SdkBindingData; +import org.flyte.flytekit.SdkBindingDataFactory; +import org.flyte.flytekit.SdkMapTask; +import org.flyte.flytekit.SdkType; +import org.flyte.flytekit.jackson.JacksonSdkType; + +@AutoService(SdkMapTask.class) +public class MapTask extends SdkMapTask { + + /** Called by subclasses passing the {@link SdkType}s for inputs and outputs. */ + public MapTask() { + super(JacksonSdkType.of(MapTask.Input.class), JacksonSdkType.of(MapTask.Output.class)); + } + + @Override + public Output run(Input input) { + return MapTask.Output.create( + SdkBindingDataFactory.of( + input.names().get().stream() + .map(name -> name + "!") + .collect(Collectors.toList()) + .stream() + .findFirst() + .get())); + } + + @AutoValue + public abstract static class Input { + public abstract SdkBindingData> names(); + + public static MapTask.Input create(SdkBindingData> greeting) { + return new AutoValue_MapTask_Input(greeting); + } + } + + /** + * Generate an immutable value class that represents {@link GreetTask}'s output, which is a + * String. + */ + @AutoValue + public abstract static class Output { + public abstract SdkBindingData greeting(); + + /** + * Wraps the constructor of the generated output value class. + * + * @param greeting the String literal output of {@link GreetTask} + * @return output of GreetTask + */ + public static MapTask.Output create(SdkBindingData greeting) { + return new AutoValue_MapTask_Output(greeting); + } + } +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkMapTask.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkMapTask.java new file mode 100644 index 000000000..4edfc53a0 --- /dev/null +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkMapTask.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.flytekit; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.flyte.api.v1.PartialTaskIdentifier; + +public abstract class SdkMapTask extends SdkRunnableTask + implements Serializable { + + /** + * Called by subclasses passing the {@link SdkType}s for inputs and outputs. + * + * @param inputType type for inputs. + * @param outputType type for outputs. + */ + public SdkMapTask(SdkType inputType, SdkType outputType) { + super(inputType, outputType); + } + + @Override + public String getType() { + return "container_array"; + } + + @Override + public String getName() { + // TODO add something random + return getClass().getName() + "MapTask"; + } + + @Override + public SdkNode apply( + SdkWorkflowBuilder builder, + String nodeId, + List upstreamNodeIds, + @Nullable SdkNodeMetadata metadata, + Map> inputs) { + // TODO add checks for the inputs and outputs + PartialTaskIdentifier taskId = PartialTaskIdentifier.builder().name(getName()).build(); + List errors = + Compiler.validateApply(nodeId, inputs, getInputType().getVariableMap()); + + if (!errors.isEmpty()) { + throw new CompilerException(errors); + } + + return new SdkTaskNode<>( + builder, nodeId, taskId, upstreamNodeIds, metadata, inputs, this.getOutputType()); + } +} diff --git a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java index fba0e4154..61b4473e8 100644 --- a/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java +++ b/jflyte-utils/src/main/java/org/flyte/jflyte/utils/ProjectClosure.java @@ -106,6 +106,7 @@ public void serialize(BiConsumer output) { int sizeDigits = (int) (Math.log10(size) + 1); AtomicInteger counter = new AtomicInteger(); + // Serialization of the tasks taskSpecs() .forEach( (id, spec) -> { @@ -207,6 +208,7 @@ static ProjectClosure load( .build(); // 1. load classes, and create templates + // Discovering and loading all the tasks, workflows, and launch plans Map runnableTasks = ClassLoaders.withClassLoader( packageClassLoader, () -> Registrars.loadAll(RunnableTaskRegistrar.class, env)); @@ -443,7 +445,9 @@ public static Map createTaskTemplates( containerTasks.forEach( (id, task) -> { - TaskTemplate taskTemplate = createTaskTemplateForContainerTask(task); + TaskTemplate taskTemplate = + createTaskTemplateForContainerTask( + task); // container image already specified inside Contat taskTemplates.put(id, taskTemplate); }); From 1add7d315c329871b8b08f453c52c71972a65f06 Mon Sep 17 00:00:00 2001 From: Rafael Raposo Date: Tue, 28 Nov 2023 11:13:27 +0100 Subject: [PATCH 2/2] wip --- .../flytekitscala/ContainerWorkflow.scala | 44 +++++++++++++++++++ .../flytekitscala/HelloContainerTask.scala | 31 +++++++++++++ .../flytekit/SdkRunnableTaskRegistrar.java | 26 ++++++++++- 3 files changed, 100 insertions(+), 1 deletion(-) create mode 100644 flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/ContainerWorkflow.scala create mode 100644 flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/HelloContainerTask.scala diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/ContainerWorkflow.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/ContainerWorkflow.scala new file mode 100644 index 000000000..e74a48b07 --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/ContainerWorkflow.scala @@ -0,0 +1,44 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekitscala.{ + SdkScalaType, + SdkScalaWorkflow, + SdkScalaWorkflowBuilder +} + +class ContainerWorkflow + extends SdkScalaWorkflow[Unit, Unit]( + SdkScalaType.unit, + SdkScalaType.unit + ) { + + /** The expand method must be implement by the workflow developer. The + * workflow developer must coding the workflow logic on this method. + * + * @param builder + * The builder which is used to build the workflow DAG. + * @param input + * The workflow input. + * @return + * The workflow output. + */ + override def expand(builder: SdkScalaWorkflowBuilder, input: Unit): Unit = { + builder.apply("hello-container-task", new HelloContainerTask()) + } +} diff --git a/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/HelloContainerTask.scala b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/HelloContainerTask.scala new file mode 100644 index 000000000..0fb3751ad --- /dev/null +++ b/flytekit-examples-scala/src/main/scala/org/flyte/examples/flytekitscala/HelloContainerTask.scala @@ -0,0 +1,31 @@ +/* + * Copyright 2023 Flyte Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.flyte.examples.flytekitscala + +import org.flyte.flytekit.SdkContainerTask +import org.flyte.flytekitscala.SdkScalaType + +class HelloContainerTask + extends SdkContainerTask[Unit, Unit]( + SdkScalaType.unit, + SdkScalaType.unit + ) { + + /** Specifies container image. */ + override def getImage: String = "alpine" + +} diff --git a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTaskRegistrar.java b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTaskRegistrar.java index daa9b9c78..b619105d4 100644 --- a/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTaskRegistrar.java +++ b/flytekit-java/src/main/java/org/flyte/flytekit/SdkRunnableTaskRegistrar.java @@ -130,11 +130,14 @@ public List getCustomJavaToolOptions() { @Override @SuppressWarnings("rawtypes") public Map load(Map env, ClassLoader classLoader) { + + Map tasks = new HashMap<>(); + ServiceLoader loader = ServiceLoader.load(SdkRunnableTask.class, classLoader); + ServiceLoader mapLoader = ServiceLoader.load(SdkMapTask.class, classLoader); LOG.fine("Discovering SdkRunnableTask"); - Map tasks = new HashMap<>(); SdkConfig sdkConfig = SdkConfig.load(env); for (SdkRunnableTask sdkTask : loader) { @@ -156,6 +159,27 @@ public Map load(Map env, ClassLoad String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); } } + /// ------ + + for (SdkRunnableTask sdkTask : mapLoader) { + String name = sdkTask.getName(); + TaskIdentifier taskId = + TaskIdentifier.builder() + .domain(sdkConfig.domain()) + .project(sdkConfig.project()) + .name(name) + .version(sdkConfig.version()) + .build(); + LOG.fine(String.format("Discovered [%s]", name)); + + RunnableTask task = new RunnableTaskImpl<>(sdkTask); + RunnableTask previous = tasks.put(taskId, task); + + if (previous != null) { + throw new IllegalArgumentException( + String.format("Discovered a duplicate task [%s] [%s] [%s]", name, task, previous)); + } + } return tasks; }