From 5dc8f5b71cb6b2366601809829136a35785f708e Mon Sep 17 00:00:00 2001
From: Yaliang Wu <ylwu@amazon.com>
Date: Fri, 27 Oct 2023 00:20:58 -0700
Subject: [PATCH] fix register client API (#1560)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
---
 .../ml/client/MachineLearningNodeClient.java  | 17 +++++++++-----
 .../register/MLRegisterModelResponse.java     | 22 +++++++++++++++++++
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java
index e1fd6445a2..ac73b397e7 100644
--- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java
+++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java
@@ -230,12 +230,7 @@ public void searchTask(SearchRequest searchRequest, ActionListener<SearchRespons
     @Override
     public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterModelResponse> listener) {
         MLRegisterModelRequest registerRequest = new MLRegisterModelRequest(mlInput);
-        client
-            .execute(
-                MLRegisterModelAction.INSTANCE,
-                registerRequest,
-                ActionListener.wrap(listener::onResponse, e -> { listener.onFailure(e); })
-            );
+        client.execute(MLRegisterModelAction.INSTANCE, registerRequest, getMLRegisterModelResponseActionListener(listener));
     }
 
     @Override
@@ -266,6 +261,16 @@ private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener
         return actionListener;
     }
 
+    private ActionListener<MLRegisterModelResponse> getMLRegisterModelResponseActionListener(
+        ActionListener<MLRegisterModelResponse> listener
+    ) {
+        ActionListener<MLRegisterModelResponse> actionListener = wrapActionListener(listener, res -> {
+            MLRegisterModelResponse registerModelResponse = MLRegisterModelResponse.fromActionResponse(res);
+            return registerModelResponse;
+        });
+        return actionListener;
+    }
+
     private <T extends ActionResponse> ActionListener<T> wrapActionListener(
         final ActionListener<T> listener,
         final Function<ActionResponse, T> recreate
diff --git a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java
index c7baa9b3a6..18c64c6c5f 100644
--- a/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java
+++ b/common/src/main/java/org/opensearch/ml/common/transport/register/MLRegisterModelResponse.java
@@ -7,13 +7,19 @@
 
 import lombok.Getter;
 import org.opensearch.core.action.ActionResponse;
+import org.opensearch.core.common.io.stream.InputStreamStreamInput;
+import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
 import org.opensearch.core.common.io.stream.StreamInput;
 import org.opensearch.core.common.io.stream.StreamOutput;
 import org.opensearch.core.xcontent.ToXContent;
 import org.opensearch.core.xcontent.ToXContentObject;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.ml.common.transport.MLTaskResponse;
 
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.io.UncheckedIOException;
 
 @Getter
 public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject {
@@ -61,4 +67,20 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
         builder.endObject();
         return builder;
     }
+
+    public static MLRegisterModelResponse fromActionResponse(ActionResponse actionResponse) {
+        if (actionResponse instanceof MLRegisterModelResponse) {
+            return (MLRegisterModelResponse) actionResponse;
+        }
+
+        try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
+             OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
+            actionResponse.writeTo(osso);
+            try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
+                return new MLRegisterModelResponse(input);
+            }
+        } catch (IOException e) {
+            throw new UncheckedIOException("failed to parse ActionResponse into MLRegisterModelResponse", e);
+        }
+    }
 }