Skip to content

Commit

Permalink
Test ppl transportaction
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <[email protected]>
  • Loading branch information
zane-neo committed Nov 24, 2023
1 parent 96aa969 commit 5a15ae4
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 1 deletion.
3 changes: 3 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ dependencies {

configurations.all {
resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.21.9'
// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-annotations:2.16.0'
// resolutionStrategy.force 'org.opensearch.client:opensearch-rest-client:2.12.0-SNAPSHOT'
// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-databind:2.16.0'
}

jacocoTestReport {
Expand Down
24 changes: 24 additions & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ dependencies {
implementation project(':opensearch-ml-memory')

implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation (group: 'opensearch-sql', name: 'opensearch-sql', version: "${common_utils_version}") {
exclude module: 'legacy'
exclude module: 'opensearch'
exclude module: 'prometheus'
exclude module: 'datasources'
exclude module: 'spark'
}
implementation (group: 'opensearch-sql', name: 'ppl', version: "${common_utils_version}") {
exclude group: 'org.reflections', module: 'reflections'
exclude group: 'com.google.guava', module: 'guava'
exclude group: 'org.json', module: 'json'
exclude module: 'common'
exclude module: 'core'
}
implementation (group: 'opensearch-sql', name: 'protocol', version: "${common_utils_version}") {
exclude group: 'com.google.guava', module: 'guava'
exclude group: 'com.fasterxml.jackson.core', module: 'jackson-core'
exclude group: 'com.fasterxml.jackson.core', module: 'jackson-databind'
exclude group: 'com.fasterxml.jackson.dataformat', module: 'jackson-dataformat-cbor'
exclude group: 'com.google.code.gson', module: 'gson'
}
implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}"
implementation "org.opensearch:common-utils:${common_utils_version}"
implementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
Expand Down Expand Up @@ -330,6 +351,9 @@ configurations.all {
resolutionStrategy.force 'org.apache.httpcomponents:httpclient:4.5.14'
resolutionStrategy.force 'commons-codec:commons-codec:1.15'
resolutionStrategy.force 'org.slf4j:slf4j-api:1.7.36'
// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-annotations:2.16.0'
// resolutionStrategy.force 'org.opensearch.client:opensearch-rest-client:2.12.0-SNAPSHOT'
// resolutionStrategy.force 'com.fasterxml.jackson.core:jackson-databind:2.16.0'
}

apply plugin: 'com.netflix.nebula.ospackage'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
//import org.opensearch.ml.rest.MyRestPPLQueryAction;
import org.opensearch.ml.rest.MyRestPPLQueryAction;
import org.opensearch.ml.rest.RestMLCreateConnectorAction;
import org.opensearch.ml.rest.RestMLDeleteConnectorAction;
import org.opensearch.ml.rest.RestMLDeleteModelAction;
Expand Down Expand Up @@ -554,6 +556,7 @@ public List<RestHandler> getRestHandlers(
RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction();
RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction();
RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting);
MyRestPPLQueryAction restPPLQueryAction = new MyRestPPLQueryAction();
return ImmutableList
.of(
restMLStatsAction,
Expand Down Expand Up @@ -587,7 +590,8 @@ public List<RestHandler> getRestHandlers(
restCreateInteractionAction,
restListInteractionsAction,
restDeleteConversationAction,
restMLUpdateConnectorAction
restMLUpdateConnectorAction,
restPPLQueryAction
);
}

Expand Down
120 changes: 120 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/rest/MyRestPPLQueryAction.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import com.google.common.collect.ImmutableList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.plugin.request.PPLQueryRequestFactory;
import org.opensearch.sql.plugin.transport.PPLQueryAction;
import org.opensearch.sql.plugin.transport.TransportPPLQueryRequest;
import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;


import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.core.rest.RestStatus.OK;

public class MyRestPPLQueryAction extends BaseRestHandler {
public static final String QUERY_API_ENDPOINT = "_ml/_ppl";
public static final String EXPLAIN_API_ENDPOINT = "_ml/_ppl/_explain";
public static final String LEGACY_QUERY_API_ENDPOINT = "_ml/_opendistro/_ppl";
public static final String LEGACY_EXPLAIN_API_ENDPOINT = "_ml/_opendistro/_ppl/_explain";

private static final Logger LOG = LogManager.getLogger();

/** Constructor of RestPPLQueryAction. */
public MyRestPPLQueryAction() {
super();
}

@Override
public List<Route> routes() {
return List.of(new Route(RestRequest.Method.POST, QUERY_API_ENDPOINT), new Route(RestRequest.Method.POST, EXPLAIN_API_ENDPOINT));
}

// @Override
// public List<ReplacedRoute> replacedRoutes() {
// return Arrays.asList(
// new ReplacedRoute(
// RestRequest.Method.POST, QUERY_API_ENDPOINT,
// RestRequest.Method.POST, LEGACY_QUERY_API_ENDPOINT),
// new ReplacedRoute(
// RestRequest.Method.POST, EXPLAIN_API_ENDPOINT,
// RestRequest.Method.POST, LEGACY_EXPLAIN_API_ENDPOINT));
// }

@Override
public String getName() {
return "ml_ppl_query_action";
}

@Override
protected Set<String> responseParams() {
Set<String> responseParams = new HashSet<>(super.responseParams());
responseParams.addAll(Arrays.asList("format", "sanitize"));
return responseParams;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) {
TransportPPLQueryRequest transportPPLQueryRequest =
new TransportPPLQueryRequest(PPLQueryRequestFactory.getPPLRequest(request));

return channel ->
nodeClient.execute(
PPLQueryAction.INSTANCE,
transportPPLQueryRequest,
new ActionListener<>() {
@Override
public void onResponse(TransportPPLQueryResponse response) {
sendResponse(channel, OK, response.getResult());
}

@Override
public void onFailure(Exception e) {
if (e instanceof IllegalAccessException) {
LOG.error("Error happened during query handling", e);
reportError(channel, e, BAD_REQUEST);
} else if (transportPPLQueryRequest.isExplainRequest()) {
LOG.error("Error happened during explain", e);
sendResponse(
channel,
INTERNAL_SERVER_ERROR,
"Failed to explain the query due to error: " + e.getMessage());
} else if (e instanceof OpenSearchSecurityException) {
OpenSearchSecurityException exception = (OpenSearchSecurityException) e;
reportError(channel, exception, exception.status());
} else {
LOG.error("Error happened during query handling", e);
reportError(channel, e, INTERNAL_SERVER_ERROR);
}
}
});
}

private void sendResponse(RestChannel channel, RestStatus status, String content) {
channel.sendResponse(new BytesRestResponse(status, "application/json; charset=UTF-8", content));
}

private void reportError(final RestChannel channel, final Exception e, final RestStatus status) {
channel.sendResponse(new BytesRestResponse(status, e.getMessage()));
}
}

0 comments on commit 5a15ae4

Please sign in to comment.