From 81ba9fa9f2659bb4dde2b6ef41a49f1b04a049a6 Mon Sep 17 00:00:00 2001 From: Andres Gomez Ferrer Date: Mon, 22 Jan 2024 17:18:17 +0100 Subject: [PATCH] Fix multiple subworkflows mocks Signed-off-by: Andres Gomez Ferrer --- .../flytekit/testing/SdkTestingExecutor.java | 3 +- .../flytekit/testing/TestingWorkflow.java | 21 ++++++-- .../testing/MockSubWorkflowsTest.java | 54 ++++++++++++++++++- 3 files changed, 71 insertions(+), 7 deletions(-) diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java index 98194915..0b7fccb7 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/SdkTestingExecutor.java @@ -406,12 +406,11 @@ public SdkTestingExecutor withWorkflowOutput( // replace workflow SdkWorkflow mockWorkflow = - new TestingWorkflow<>(inputType, outputType, fixedTask.fixedOutputs); + new TestingWorkflow<>(inputType, outputType, fixedTask.fixedOutputs, workflow.getName()); return toBuilder() .putWorkflowTemplate(workflow.getName(), mockWorkflow.toIdlTemplate()) .putFixedTask(workflow.getName(), fixedTask) - .putFixedTask(TestingWorkflow.TestingSdkRunnableTask.class.getName(), fixedTask) .build(); } diff --git a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java index f5b2bd5f..1148f21e 100644 --- a/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java +++ b/flytekit-testing/src/main/java/org/flyte/flytekit/testing/TestingWorkflow.java @@ -26,17 +26,22 @@ class TestingWorkflow extends SdkWorkflow { private final Map outputs; + private final String name; TestingWorkflow( - SdkType inputType, SdkType outputType, Map outputs) { + SdkType inputType, + SdkType outputType, + Map outputs, + String name) { super(inputType, outputType); this.outputs = outputs; + this.name = name; } @Override public OutputT expand(SdkWorkflowBuilder builder, InputT input) { return builder - .apply(new TestingSdkRunnableTask<>(getInputType(), getOutputType(), outputs), input) + .apply(new TestingSdkRunnableTask<>(getInputType(), getOutputType(), outputs, name), input) .getOutputs(); } @@ -45,11 +50,21 @@ public static class TestingSdkRunnableTask private static final long serialVersionUID = 6106269076155338045L; private final Map outputs; + private final String name; + + @Override + public String getName() { + return name; + } public TestingSdkRunnableTask( - SdkType inputType, SdkType outputType, Map outputs) { + SdkType inputType, + SdkType outputType, + Map outputs, + String name) { super(inputType, outputType); this.outputs = outputs; + this.name = name; } @Override diff --git a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java index 6e96ea2a..2bfc70dc 100644 --- a/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java +++ b/flytekit-testing/src/test/java/org/flyte/flytekit/testing/MockSubWorkflowsTest.java @@ -41,6 +41,12 @@ public void test() { SubWorkflowInputs.create(SdkBindingDataFactory.of(1)), JacksonSdkType.of(SubWorkflowOutputs.class), SubWorkflowOutputs.create(SdkBindingDataFactory.of(10))) + .withWorkflowOutput( + new SubWorkflow1(), + JacksonSdkType.of(SubWorkflowInputs1.class), + SubWorkflowInputs1.create(SdkBindingDataFactory.of(11)), + JacksonSdkType.of(SubWorkflowOutputs1.class), + SubWorkflowOutputs1.create(SdkBindingDataFactory.of(110))) .withWorkflowOutput( new SubWorkflow(), JacksonSdkType.of(SubWorkflowInputs.class), @@ -64,10 +70,22 @@ public SubWorkflowOutputs expand(SdkWorkflowBuilder builder, SubWorkflowInputs i var subOut1 = builder .apply( - "sub", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(1))) + "subworkflow-1", + new SubWorkflow(), + SubWorkflowInputs.create(SdkBindingDataFactory.of(1))) .getOutputs(); builder - .apply("sub1", new SubWorkflow(), SubWorkflowInputs.create(SdkBindingDataFactory.of(2))) + .apply( + "subworkflow1-1", + new SubWorkflow1(), + SubWorkflowInputs1.create(SdkBindingDataFactory.of(11))) + .getOutputs(); + + builder + .apply( + "subworkflow-2", + new SubWorkflow(), + SubWorkflowInputs.create(SdkBindingDataFactory.of(2))) .getOutputs(); return SubWorkflowOutputs.create(subOut1.o()); @@ -104,4 +122,36 @@ public static MockSubWorkflowsTest.SubWorkflowOutputs create(SdkBindingData { + public SubWorkflow1() { + super( + JacksonSdkType.of(SubWorkflowInputs1.class), + JacksonSdkType.of(SubWorkflowOutputs1.class)); + } + + @Override + public SubWorkflowOutputs1 expand(SdkWorkflowBuilder builder, SubWorkflowInputs1 inputs) { + + return SubWorkflowOutputs1.create(inputs.a1()); + } + } + + @AutoValue + public abstract static class SubWorkflowInputs1 { + public abstract SdkBindingData a1(); + + public static MockSubWorkflowsTest.SubWorkflowInputs1 create(SdkBindingData a1) { + return new AutoValue_MockSubWorkflowsTest_SubWorkflowInputs1(a1); + } + } + + @AutoValue + public abstract static class SubWorkflowOutputs1 { + public abstract SdkBindingData o1(); + + public static MockSubWorkflowsTest.SubWorkflowOutputs1 create(SdkBindingData o1) { + return new AutoValue_MockSubWorkflowsTest_SubWorkflowOutputs1(o1); + } + } }