From 5dc8f5b71cb6b2366601809829136a35785f708e Mon Sep 17 00:00:00 2001 From: Yaliang Wu <ylwu@amazon.com> Date: Fri, 27 Oct 2023 00:20:58 -0700 Subject: [PATCH] fix register client API (#1560) Signed-off-by: Yaliang Wu <ylwu@amazon.com> --- .../ml/client/MachineLearningNodeClient.java | 17 +++++++++----- .../register/MLRegisterModelResponse.java | 22 +++++++++++++++++++ 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index e1fd6445a2..ac73b397e7 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -230,12 +230,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons @Override public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) { MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput); - client - .execute( - MLRegisterModelAction.INSTANCE, - registerRequest, - ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); }) - ); + client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener)); } @Override @@ -266,6 +261,16 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener return actionListener; } + private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener( + ActionListener<MLRegisterModelResponse> listener + ) { + ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> { + MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res); + return registerModelResponse; + }); + return actionListener; + } + private <T extends ActionResponse> ActionListener<T> wrapActionListener( final ActionListener<T> listener, final Function<ActionResponse, T> recreate diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java index c7baa9b3a6..18c64c6c5f 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java @@ -7,13 +7,19 @@ import lombok.Getter; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.transport.MLTaskResponse; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; @Getter public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject { @@ -61,4 +67,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par builder.endObject(); return builder; } + + public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof MLRegisterModelResponse) { + return (MLRegisterModelResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new MLRegisterModelResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e); + } + } }