diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 4b197d99b..1b1875177 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -18,13 +18,14 @@ import java.util.HashMap; import java.util.Map; +import java.util.function.Supplier; /** * Generates instances implementing {@link WorkflowStep}. */ public class WorkflowStepFactory { - private final Map stepMap = new HashMap<>(); + private final Map> stepMap = new HashMap<>(); /** * Instantiate this class. @@ -42,21 +43,21 @@ public WorkflowStepFactory( MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { - stepMap.put(NoOpStep.NAME, new NoOpStep()); - stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); - stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); + stepMap.put(NoOpStep.NAME, NoOpStep::new); + stepMap.put(CreateIndexStep.NAME, () -> new CreateIndexStep(clusterService, client, flowFrameworkIndicesHandler)); + stepMap.put(CreateIngestPipelineStep.NAME, () -> new CreateIngestPipelineStep(client, flowFrameworkIndicesHandler)); stepMap.put( RegisterLocalModelStep.NAME, - new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) + () -> new RegisterLocalModelStep(settings, clusterService, mlClient, flowFrameworkIndicesHandler) ); - stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); - stepMap.put(UndeployModelStep.NAME, new UndeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); - stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(ToolStep.NAME, new ToolStep()); - stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); + stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeployModelStep.NAME, () -> new DeployModelStep(mlClient)); + stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); + stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(ToolStep.NAME, ToolStep::new); + stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient)); } /** @@ -66,7 +67,7 @@ public WorkflowStepFactory( */ public WorkflowStep createStep(String type) { if (stepMap.containsKey(type)) { - return stepMap.get(type); + return stepMap.get(type).get(); } throw new FlowFrameworkException("Workflow step type [" + type + "] is not implemented.", RestStatus.NOT_IMPLEMENTED); } @@ -75,7 +76,7 @@ public WorkflowStep createStep(String type) { * Gets the step map * @return a read-only copy of the step map */ - public Map getStepMap() { + public Map> getStepMap() { return Map.copyOf(this.stepMap); } }