Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Register and Deploy Model Step for remote model #52

Merged
merged 13 commits into from
Oct 12, 2023
3 changes: 3 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ publishing {
allprojects {
group = opensearch_group
version = "${opensearch_build}"
}

java {
targetCompatibility = JavaVersion.VERSION_11
sourceCompatibility = JavaVersion.VERSION_11
}
Expand Down
38 changes: 38 additions & 0 deletions gradle.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#


# Enable build caching
org.gradle.caching=true
org.gradle.warning.mode=none
org.gradle.parallel=true
# Workaround for https://github.com/diffplug/spotless/issues/834
org.gradle.jvmargs=-Xmx3g -XX:+HeapDumpOnOutOfMemoryError -Xss2m \
--add-exports jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.file=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.parser=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED \
--add-exports jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED
options.forkOptions.memoryMaximumSize=2g

# Disable duplicate project id detection
# See https://docs.gradle.org/current/userguide/upgrading_version_6.html#duplicate_project_names_may_cause_publication_to_fail
systemProp.org.gradle.dependency.duplicate.project.detection=false

# Enforce the build to fail on deprecated gradle api usage
systemProp.org.gradle.warning.mode=fail

# forcing to use TLS1.2 to avoid failure in vault
# see https://github.com/hashicorp/vault/issues/8750#issuecomment-631236121
systemProp.jdk.tls.client.protocols=TLSv1.2

# jvm args for faster test execution by default
systemProp.tests.jvm.argline=-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m
2 changes: 1 addition & 1 deletion src/main/java/demo/DemoWorkflowStep.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {
CompletableFuture.runAsync(() -> {
try {
Thread.sleep(this.delay);
future.complete(null);
future.complete(WorkflowData.EMPTY);
} catch (InterruptedException e) {
future.completeExceptionally(e);
}
Expand Down
1 change: 0 additions & 1 deletion src/main/java/demo/TemplateParseDemo.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ public static void main(String[] args) throws IOException {
}
ClusterService clusterService = new ClusterService(null, null, null);
Client client = new NodeClient(null, null);

WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client);
ThreadPool threadPool = new ThreadPool(Settings.EMPTY);
WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
*/
package org.opensearch.flowframework.client;

import org.opensearch.client.node.NodeClient;
import org.opensearch.client.Client;
import org.opensearch.ml.client.MachineLearningNodeClient;

/**
Expand All @@ -22,12 +22,12 @@ private MLClient() {}
/**
* Creates machine learning client.
*
* @param nodeClient node client of OpenSearch.
* @param client client of OpenSearch.
* @return machine learning client from ml-commons.
*/
public static MachineLearningNodeClient createMLClient(NodeClient nodeClient) {
public static MachineLearningNodeClient createMLClient(Client client) {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
if (INSTANCE == null) {
INSTANCE = new MachineLearningNodeClient(nodeClient);
INSTANCE = new MachineLearningNodeClient(client);
}
return INSTANCE;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

/**
* Step to deploy a model
*/
public class DeployModelStep implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModelStep.class);

private Client client;
private static final String MODEL_ID = "model_id";
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
static final String NAME = "deploy_model";

/**
* Instantiate this class
* @param client client to instantiate MLClient
*/
public DeployModelStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
logger.info("Model deployment state {}", mlDeployModelResponse.getStatus());
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy_model_status", mlDeployModelResponse.getStatus())))

Check warning on line 53 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L51-L53

Added lines #L51 - L53 were not covered by tests
);
}

Check warning on line 55 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L55

Added line #L55 was not covered by tests

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
}
};

String modelId = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();
for (Map.Entry<String, Object> entry : content.entrySet()) {
if (entry.getKey() == MODEL_ID) {
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
modelId = (String) content.get(MODEL_ID);
}

}
}
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
machineLearningNodeClient.deploy(modelId, actionListener);
owaiskazi19 marked this conversation as resolved.
Show resolved Hide resolved
return deployModelFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 81 in src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java#L81

Added line #L81 was not covered by tests
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.task.MLTaskGetResponse;

/**
* Step to get modelID of a registered local model
*/
@SuppressForbidden(reason = "This class is for the future work of registering local model")
public class GetTask {

private static final Logger logger = LogManager.getLogger(GetTask.class);

Check warning on line 26 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L26

Added line #L26 was not covered by tests
private MachineLearningNodeClient machineLearningNodeClient;
private String taskId;

/**
* Instantiate this class
* @param machineLearningNodeClient client to instantiate ml-commons APIs
* @param taskId taskID of the model
*/
public GetTask(MachineLearningNodeClient machineLearningNodeClient, String taskId) {
this.machineLearningNodeClient = machineLearningNodeClient;
this.taskId = taskId;
}

Check warning on line 38 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L35-L38

Added lines #L35 - L38 were not covered by tests

/**
* Invokes get task API of ml-commons
*/
public void getTask() {

ActionListener<MLTask> actionListener = new ActionListener<>() {

Check warning on line 45 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L45

Added line #L45 was not covered by tests
@Override
public void onResponse(MLTask mlTask) {
if (mlTask.getState() == MLTaskState.COMPLETED) {
logger.info("Model registration successful");
MLTaskGetResponse response = MLTaskGetResponse.builder().mlTask(mlTask).build();
logger.info("Response from task {}", response);

Check warning on line 51 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L49-L51

Added lines #L49 - L51 were not covered by tests
}
}

Check warning on line 53 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L53

Added line #L53 was not covered by tests

@Override
public void onFailure(Exception e) {
logger.error("Model registration failed");
}

Check warning on line 58 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L57-L58

Added lines #L57 - L58 were not covered by tests
};

machineLearningNodeClient.getTask(taskId, actionListener);

Check warning on line 61 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L61

Added line #L61 was not covered by tests

}

Check warning on line 63 in src/main/java/org/opensearch/flowframework/workflow/GetTask.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/GetTask.java#L63

Added line #L63 was not covered by tests

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

/**
* Step to register a remote model
*/
public class RegisterModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private Client client;

static final String NAME = "register_model";

private static final String FUNCTION_NAME = "function_name";
private static final String MODEL_NAME = "name";
private static final String MODEL_VERSION = "model_version";
private static final String MODEL_GROUP_ID = "model_group_id";
private static final String DESCRIPTION = "description";
private static final String CONNECTOR_ID = "connector_id";
private static final String MODEL_FORMAT = "model_format";
private static final String MODEL_CONFIG = "model_config";

/**
* Instantiate this class
* @param client client to instantiate MLClient
*/
public RegisterModelStep(Client client) {
this.client = client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client);

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
logger.info("Model registration successful");
registerModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("model_id", mlRegisterModelResponse.getModelId()),
Map.entry("register_model_status", mlRegisterModelResponse.getStatus())

Check warning on line 74 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L72-L74

Added lines #L72 - L74 were not covered by tests
)
)
);
}

Check warning on line 78 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L78

Added line #L78 was not covered by tests

@Override
public void onFailure(Exception e) {
logger.error("Failed to register model");
registerModelFuture.completeExceptionally(new IOException("Failed to register model "));
}
};

FunctionName functionName = null;
String modelName = null;
String modelVersion = null;
String modelGroupId = null;
String connectorId = null;
String description = null;
MLModelFormat modelFormat = null;
MLModelConfig modelConfig = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case FUNCTION_NAME:
functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
break;
case MODEL_NAME:
modelName = (String) content.get(MODEL_NAME);
break;
case MODEL_VERSION:
modelVersion = (String) content.get(MODEL_VERSION);
break;

Check warning on line 109 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L108-L109

Added lines #L108 - L109 were not covered by tests
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;

Check warning on line 112 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L111-L112

Added lines #L111 - L112 were not covered by tests
case MODEL_FORMAT:
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT));
break;

Check warning on line 115 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L114-L115

Added lines #L114 - L115 were not covered by tests
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;

Check warning on line 118 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L117-L118

Added lines #L117 - L118 were not covered by tests
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
break;
default:
break;

}
}
}

if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) {

// TODO: Add model Config and type cast correctly
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.description(description)
.connectorId(connectorId)
.build();

machineLearningNodeClient.register(mlInput, actionListener);
}

return registerModelFuture;
}

@Override
public String getName() {
return NAME;

Check warning on line 150 in src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java#L150

Added line #L150 was not covered by tests
}
}
Loading