Skip to content

Commit

Permalink
More steps using the new input parsing
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 1, 2023
1 parent 2c82174 commit 80a6fa6
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.model.MLModelConfig;
Expand All @@ -28,10 +29,8 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

Expand Down Expand Up @@ -81,12 +80,6 @@ public CompletableFuture<WorkflowData> execute(

CompletableFuture<WorkflowData> registerLocalModelFuture = new CompletableFuture<>();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) {
Expand All @@ -105,76 +98,41 @@ public void onFailure(Exception e) {
}
};

String modelName = null;
String modelVersion = null;
String description = null;
MLModelFormat modelFormat = null;
String modelGroupId = null;
String modelContentHashValue = null;
String modelType = null;
String embeddingDimension = null;
FrameworkType frameworkType = null;
String allConfig = null;
String url = null;

for (WorkflowData workflowData : data) {
Map<String, Object> content = workflowData.getContent();

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelName = (String) content.get(NAME_FIELD);
break;
case VERSION_FIELD:
modelVersion = (String) content.get(VERSION_FIELD);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case MODEL_FORMAT:
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT));
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_TYPE:
modelType = (String) content.get(MODEL_TYPE);
break;
case EMBEDDING_DIMENSION:
embeddingDimension = (String) content.get(EMBEDDING_DIMENSION);
break;
case FRAMEWORK_TYPE:
frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE));
break;
case ALL_CONFIG:
allConfig = (String) content.get(ALL_CONFIG);
break;
case MODEL_CONTENT_HASH_VALUE:
modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE);
break;
case URL:
url = (String) content.get(URL);
break;
default:
break;

}
}
}

if (Stream.of(
modelName,
modelVersion,
modelFormat,
modelGroupId,
modelType,
embeddingDimension,
frameworkType,
modelContentHashValue,
url
).allMatch(x -> x != null)) {
Set<String> requiredKeys = Set.of(
NAME_FIELD,
VERSION_FIELD,
MODEL_FORMAT,
MODEL_GROUP_ID,
MODEL_TYPE,
EMBEDDING_DIMENSION,
FRAMEWORK_TYPE,
MODEL_CONTENT_HASH_VALUE,
URL
);
Set<String> optionalKeys = Set.of(DESCRIPTION_FIELD, ALL_CONFIG);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

// Create Model configudation
String modelName = (String) inputs.get(NAME_FIELD);
String modelVersion = (String) inputs.get(VERSION_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
MLModelFormat modelFormat = MLModelFormat.from((String) inputs.get(MODEL_FORMAT));
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String modelContentHashValue = (String) inputs.get(MODEL_CONTENT_HASH_VALUE);
String modelType = (String) inputs.get(MODEL_TYPE);
String embeddingDimension = (String) inputs.get(EMBEDDING_DIMENSION);
FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE));
String allConfig = (String) inputs.get(ALL_CONFIG);
String url = (String) inputs.get(URL);

// Create Model configuration
TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder()
.modelType(modelType)
.embeddingDimension(Integer.valueOf(embeddingDimension))
Expand All @@ -200,12 +158,9 @@ public void onFailure(Exception e) {
MLRegisterModelInput mlInput = mlInputBuilder.build();

mlClient.register(mlInput, actionListener);
} else {
registerLocalModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
} catch (FlowFrameworkException e) {
registerLocalModelFuture.completeExceptionally(e);
}

return registerLocalModelFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,18 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD;
Expand Down Expand Up @@ -88,48 +85,23 @@ public void onFailure(Exception e) {
}
};

String modelName = null;
FunctionName functionName = null;
String modelGroupId = null;
String description = null;
String connectorId = null;

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());

// TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149

for (WorkflowData workflowData : data) {

Map<String, Object> content = workflowData.getContent();
for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
modelName = (String) content.get(NAME_FIELD);
break;
case FUNCTION_NAME:
functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case DESCRIPTION_FIELD:
description = (String) content.get(DESCRIPTION_FIELD);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
break;
default:
break;

}
}
}
Set<String> requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID);
Set<String> optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD);

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) {
String modelName = (String) inputs.get(NAME_FIELD);
FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(functionName)
Expand All @@ -145,12 +117,10 @@ public void onFailure(Exception e) {
MLRegisterModelInput mlInput = builder.build();

mlClient.register(mlInput, actionListener);
} else {
registerRemoteModelFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
}

} catch (FlowFrameworkException e) {
registerRemoteModelFuture.completeExceptionally(e);
}
return registerRemoteModelFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,27 @@ public void testRegisterLocalModelTaskFailure() {
public void testMissingInputs() {
CompletableFuture<WorkflowData> future = registerLocalModelStep.execute(
"nodeId",
WorkflowData.EMPTY,
new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"),
Collections.emptyMap(),
Collections.emptyMap()
);
assertTrue(future.isDone());
assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Required fields are not provided", ex.getCause().getMessage());
assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs ["));
for (String s : new String[] {
"model_format",
"name",
"model_type",
"embedding_dimension",
"framework_type",
"model_group_id",
"version",
"url",
"model_content_hash_value" }) {
assertTrue(ex.getCause().getMessage().contains(s));
}
assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,19 @@ public void testRegisterRemoteModelFailure() {
public void testMissingInputs() {
CompletableFuture<WorkflowData> future = this.registerRemoteModelStep.execute(
"nodeId",
WorkflowData.EMPTY,
new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"),
Collections.emptyMap(),
Collections.emptyMap()
);
assertTrue(future.isDone());
assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Required fields are not provided", ex.getCause().getMessage());
assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs ["));
for (String s : new String[] { "name", "function_name", "connector_id" }) {
assertTrue(ex.getCause().getMessage().contains(s));
}
assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));
}

}

0 comments on commit 80a6fa6

Please sign in to comment.