Skip to content

Commit

Permalink
Added NodeClient
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 4, 2023
1 parent ee6ddc3 commit 353a322
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 22 deletions.
3 changes: 2 additions & 1 deletion src/main/java/demo/Demo.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ public static void main(String[] args) throws IOException {
return;
}
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(client);
NodeClient nodeClient = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(client, nodeClient);

ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public static void main(String[] args) throws IOException {
return;
}
Client client = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(client);
NodeClient nodeClient = new NodeClient(null, null);
WorkflowStepFactory factory = new WorkflowStepFactory(client, nodeClient);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import com.google.common.collect.ImmutableList;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.env.Environment;
Expand Down Expand Up @@ -51,7 +53,10 @@ public Collection<Object> createComponents(
IndexNameExpressionResolver indexNameExpressionResolver,
Supplier<RepositoriesService> repositoriesServiceSupplier
) {
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(client);
Settings settings = environment.settings();
// TODO: Creating NodeClient is a temporary fix until we get the NodeClient from the provision API
NodeClient nodeClient = new NodeClient(settings, threadPool);
WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(client, nodeClient);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool);

return ImmutableList.of(workflowStepFactory, workflowProcessSorter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
Expand All @@ -21,15 +20,15 @@
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class DeployModel implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModel.class);
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private NodeClient nodeClient;
private static final String MODEL_ID = "model_id";
static final String NAME = "deploy_model";

public DeployModel(Client client) {
this.nodeClient = (NodeClient) client;
public DeployModelStep(NodeClient nodeClient) {
this.nodeClient = nodeClient;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;

@SuppressForbidden(reason = "This class is for the future work of registering local model")
public class GetTask {

private static final Logger logger = LogManager.getLogger(GetTask.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
Expand All @@ -20,7 +19,6 @@
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.threadpool.Scheduler;

import java.io.IOException;
import java.util.List;
Expand All @@ -35,7 +33,6 @@ public class RegisterModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private NodeClient nodeClient;
private volatile Scheduler.Cancellable scheduledFuture;

static final String NAME = "register_model";

Expand All @@ -48,8 +45,8 @@ public class RegisterModelStep implements WorkflowStep {
private static final String MODEL_FORMAT = "model_format";
private static final String MODEL_CONFIG = "model_config";

public RegisterModelStep(Client client) {
this.nodeClient = (NodeClient) client;
public RegisterModelStep(NodeClient nodeClient) {
this.nodeClient = nodeClient;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.flowframework.workflow;

import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;

import java.util.HashMap;
import java.util.List;
Expand All @@ -29,15 +30,15 @@ public class WorkflowStepFactory {
*
* @param client The OpenSearch client steps can use
*/
public WorkflowStepFactory(Client client) {
populateMap(client);
public WorkflowStepFactory(Client client, NodeClient nodeClient) {
populateMap(client, nodeClient);
}

private void populateMap(Client client) {
private void populateMap(Client client, NodeClient nodeClient) {
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(client));
stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client));
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client));
stepMap.put(DeployModel.NAME, new DeployModel(client));
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(nodeClient));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(nodeClient));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.env.Environment;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -41,7 +43,8 @@ public void tearDown() throws Exception {

public void testPlugin() throws IOException {
try (FlowFrameworkPlugin ffp = new FlowFrameworkPlugin()) {
assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, null, null, null, null, null).size());
Environment env = new Environment(Settings.builder().put("path.home", "dummy").build(), null);
assertEquals(2, ffp.createComponents(client, null, threadPool, null, null, null, env, null, null, null, null).size());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import static org.mockito.Mockito.*;

@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class DeployModelTests extends OpenSearchTestCase {
public class DeployModelStepTests extends OpenSearchTestCase {

private WorkflowData inputData = WorkflowData.EMPTY;

Expand Down Expand Up @@ -59,7 +59,7 @@ public void testDeployModel() {
String status = MLTaskState.CREATED.name();
MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL;

DeployModel deployModel = new DeployModel(nodeClient);
DeployModelStep deployModel = new DeployModelStep(nodeClient);

ArgumentCaptor<ActionListener<MLDeployModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.client.AdminClient;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.model.TemplateTestJsonUtil;
import org.opensearch.flowframework.model.Workflow;
Expand Down Expand Up @@ -58,10 +59,11 @@ private static List<String> parse(String json) throws IOException {
public static void setup() {
AdminClient adminClient = mock(AdminClient.class);
Client client = mock(Client.class);
NodeClient nodeClient = mock(NodeClient.class);
when(client.admin()).thenReturn(adminClient);

testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName());
WorkflowStepFactory factory = new WorkflowStepFactory(client);
WorkflowStepFactory factory = new WorkflowStepFactory(client, nodeClient);
workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool);
}

Expand Down

0 comments on commit 353a322

Please sign in to comment.