Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to select SQL engine with each query. #160

Draft
wants to merge 4 commits into
base: integ-add-engine-selector
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/EngineSwitchIT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/


package org.opensearch.sql.sql;

import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS;
import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT;
import static org.opensearch.sql.util.TestUtils.getResponseBody;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.Locale;
import com.google.common.collect.Lists;
import lombok.SneakyThrows;
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.client.Request;
import org.opensearch.client.RequestOptions;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.sql.legacy.SQLIntegTestCase;

public class EngineSwitchIT extends SQLIntegTestCase {

@Override
public void init() throws Exception {
super.init();
loadIndex(Index.CALCS);
}

@Test
@SneakyThrows
public void test_no_param_set_v2_query() {
var result = executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s", TEST_INDEX_CALCS), null);
assertEquals(17, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by new SQL query engine with fallback option to legacy", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine without fallback to legacy")));
}

@Test
@SneakyThrows
public void test_no_param_set_v1_query() {
var result = executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s, %s",
TEST_INDEX_CALCS, TEST_INDEX_CALCS), null);
assertEquals(17, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is not supported and falling back to old SQL engine", List.of(
"Request is handled by new SQL query engine with fallback option to legacy",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine without fallback to legacy")));
}

// Any value, but 'v1`, 'legacy', 'v2' interpreted as no value set and should work with fallback
@Test
@SneakyThrows
public void test_some_param_set_v2_query() {
var result = executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s", TEST_INDEX_CALCS), "fallback");
assertEquals(17, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by new SQL query engine with fallback option to legacy", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine without fallback to legacy")));
}

@Test
@SneakyThrows
public void test_some_param_set_v1_query() {
var result = executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s, %s",
TEST_INDEX_CALCS, TEST_INDEX_CALCS), "null");
assertEquals(17, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is not supported and falling back to old SQL engine", List.of(
"Request is handled by new SQL query engine with fallback option to legacy",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine without fallback to legacy")));
}

@Test
@SneakyThrows
public void test_v2_param_set_v2_query() {
var result = executeQueryOnEngine("SELECT 1", "v2");
assertEquals(1, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by new SQL query engine without fallback to legacy", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine with fallback option to legacy")));
}

@Test
@SneakyThrows
public void test_v1_param_set_v2_query() {
var exception = assertThrows(ResponseException.class, () -> executeQueryOnEngine("SELECT 1", "legacy"));
assertTrue(exception.getMessage().contains("Invalid SQL query"));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by old SQL engine only", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by new SQL query engine without fallback to legacy",
"Request is handled by new SQL query engine with fallback option to legacy")));
}

@Test
@SneakyThrows
public void test_v2_param_set_v1_query() {
var exception = assertThrows(ResponseException.class, () ->
executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s, %s",
TEST_INDEX_CALCS, TEST_INDEX_CALCS), "v2"));
assertTrue(exception.getMessage().contains("Invalid SQL query"));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by new SQL query engine without fallback to legacy", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by old SQL engine only",
"Request is handled by new SQL query engine with fallback option to legacy")));
}

@Test
@SneakyThrows
public void test_v1_param_set_v1_query() {
var result = executeQueryOnEngine(String.format("SELECT COUNT(*) FROM %s, %s",
TEST_INDEX_CALCS, TEST_INDEX_CALCS), "v1");
assertEquals(17, result.getJSONArray("datarows").getJSONArray(0).getInt(0));
assertTrue(findLogLineAfterAnotherLine(
"Request is handled by old SQL engine only", List.of(
"Request is not supported and falling back to old SQL engine",
"Request is handled by new SQL query engine without fallback to legacy",
"Request is handled by new SQL query engine with fallback option to legacy")));
}

/**
* Function looks for the given line after another line(s) in IT cluster log from the end.
* @return true if found.
*/
@SneakyThrows
private Boolean findLogLineAfterAnotherLine(String line, List<String> linesBefore) {
var logDir = getAllClusterSettings().query("/defaults/path.logs");
var lines = Files.readAllLines(Paths.get(logDir.toString(), "integTest.log"));
for (String logLine : Lists.reverse(lines)) {
if (logLine.contains(line)) {
return true;
}
for (var lineBefore : linesBefore) {
if (logLine.contains(lineBefore)) {
return false;
}
}
}
return false;
}

protected JSONObject executeQueryOnEngine(String query, String engine) throws IOException {
var endpoint = engine == null ? QUERY_API_ENDPOINT : QUERY_API_ENDPOINT + "?engine=" + engine;
Request request = new Request("POST", endpoint);
request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query));

RequestOptions.Builder restOptionsBuilder = RequestOptions.DEFAULT.toBuilder();
restOptionsBuilder.addHeader("Content-Type", "application/json");
request.setOptions(restOptionsBuilder);

Response response = client().performRequest(request);
return new JSONObject(getResponseBody(response));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,29 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
// Route request to new query engine if it's supported already
SQLQueryRequest newSqlRequest = new SQLQueryRequest(sqlRequest.getJsonContent(),
sqlRequest.getSql(), request.path(), request.params());
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
LOG.info("[{}] Request is handled by new SQL query engine",
QueryContext.getRequestId());

if (newSqlRequest.getEngine().toLowerCase().contains("v2")) {
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
LOG.info("[{}] Request is handled by new SQL query engine without fallback to legacy", QueryContext.getRequestId());
result.accept(channel);
} else {
LOG.debug("[{}] Request {} is not supported and falling back to old SQL engine",
QueryContext.getRequestId(), newSqlRequest);
} else if (newSqlRequest.getEngine().toLowerCase().contains("v1") ||
newSqlRequest.getEngine().toLowerCase().contains("legacy")) {
LOG.info("[{}] Request is handled by old SQL engine only",
QueryContext.getRequestId());
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, channel);
} else {
RestChannelConsumer result = newSqlQueryHandler.prepareRequest(newSqlRequest, client);
if (result != RestSQLQueryAction.NOT_SUPPORTED_YET) {
LOG.info("[{}] Request is handled by new SQL query engine with fallback option to legacy",
QueryContext.getRequestId());
result.accept(channel);
} else {
LOG.info("[{}] Request is not supported and falling back to old SQL engine",
QueryContext.getRequestId());
QueryAction queryAction = explainRequest(client, sqlRequest, format);
executeSqlRequest(request, queryAction, client, channel);
}
}
} catch (Exception e) {
logAndPublishMetrics(e);
Expand All @@ -180,7 +193,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
@Override
protected Set<String> responseParams() {
Set<String> responseParams = new HashSet<>(super.responseParams());
responseParams.addAll(Arrays.asList("sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize"));
responseParams.addAll(Arrays.asList("sql", "flat", "separator", "_score", "_type", "_id", "newLine", "format", "sanitize", "engine"));
return responseParams;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public void handleQueryThatCanSupport() {
new JSONObject("{\"query\": \"SELECT -123\"}"),
"SELECT -123",
QUERY_API_ENDPOINT,
"");
"", "");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
Expand All @@ -69,7 +69,7 @@ public void handleExplainThatCanSupport() {
new JSONObject("{\"query\": \"SELECT -123\"}"),
"SELECT -123",
EXPLAIN_API_ENDPOINT,
"");
"", "");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
assertNotSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
Expand All @@ -82,7 +82,7 @@ public void skipQueryThatNotSupport() {
"{\"query\": \"SELECT name FROM test1 JOIN test2 ON test1.name = test2.name\"}"),
"SELECT name FROM test1 JOIN test2 ON test1.name = test2.name",
QUERY_API_ENDPOINT,
"");
"", "");

RestSQLQueryAction queryAction = new RestSQLQueryAction(clusterService, settings, catalogService);
assertSame(NOT_SUPPORTED_YET, queryAction.prepareRequest(request, nodeClient));
Expand Down
13 changes: 11 additions & 2 deletions sql-cli/src/opensearch_sql_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
default="sql",
help="SQL OR PPL",
)
@click.option(
"--engine",
"engine",
type=click.STRING,
help="SQL engine to use: V2, V1 (aka legacy) or both"
)
def cli(
endpoint,
query,
Expand All @@ -83,6 +89,7 @@ def cli(
always_use_pager,
use_aws_authentication,
query_language,
engine
):
"""
Provide endpoint for OpenSearch client.
Expand All @@ -101,9 +108,10 @@ def cli(
opensearch_executor = OpenSearchConnection(endpoint, http_auth, use_aws_authentication)
opensearch_executor.set_connection()
if explain:
output = opensearch_executor.execute_query(query, explain=True, use_console=False)
output = opensearch_executor.execute_query(query, explain=True, use_console=False, engine=engine)
else:
output = opensearch_executor.execute_query(query, output_format=result_format, use_console=False)
output = opensearch_executor.execute_query(query, output_format=result_format, use_console=False,
engine=engine)
if output and result_format == "jdbc":
settings = OutputSettings(table_format="psql", is_vertical=is_vertical)
formatter = Formatter(settings)
Expand All @@ -119,6 +127,7 @@ def cli(
always_use_pager=always_use_pager,
use_aws_authentication=use_aws_authentication,
query_language=query_language,
engine=engine
)
opensearchsql_cli.connect(endpoint, http_auth)
opensearchsql_cli.run_cli()
Expand Down
9 changes: 6 additions & 3 deletions sql-cli/src/opensearch_sql_cli/opensearch_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,35 +146,38 @@ def handle_server_close_connection(self):
)
click.secho(repr(reconnection_err), err=True, fg="red")

def execute_query(self, query, output_format="jdbc", explain=False, use_console=True):
def execute_query(self, query, output_format="jdbc", explain=False, use_console=True, engine=None):
"""
Handle user input, send SQL query and get response.

:param use_console: use console to interact with user, otherwise it's single query
:param query: SQL query
:param output_format: jdbc/csv
:param explain: if True, use _explain API.
:param engine: SQL engine to use
:return: raw http response
"""

# TODO: consider add evaluator/handler to filter obviously-invalid input,
# to save cost of http client.
# deal with input
final_query = query.strip().strip(";")
params = None if explain else ({"format": output_format} if engine is None else
{"format": output_format, "engine": engine})

try:
if self.query_language == "sql":
data = self.client.transport.perform_request(
url="/_plugins/_sql/_explain" if explain else "/_plugins/_sql/",
method="POST",
params=None if explain else {"format": output_format},
params=params,
body={"query": final_query},
)
else:
data = self.client.transport.perform_request(
url="/_plugins/_ppl/_explain" if explain else "/_plugins/_ppl/",
method="POST",
params=None if explain else {"format": output_format},
params=params,
body={"query": final_query},
)
return data
Expand Down
7 changes: 5 additions & 2 deletions sql-cli/src/opensearch_sql_cli/opensearchsql_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@
class OpenSearchSqlCli:
"""OpenSearchSqlCli instance is used to build and run the OpenSearch SQL CLI."""

def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authentication=False, query_language="sql"):
def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authentication=False, query_language="sql",
engine=None):

# Load conf file
config = self.config = get_config(clirc_file)
literal = self.literal = self._get_literals()
Expand All @@ -49,6 +51,7 @@ def __init__(self, clirc_file=None, always_use_pager=False, use_aws_authenticati
self.query_language = query_language
self.always_use_pager = always_use_pager
self.use_aws_authentication = use_aws_authentication
self.engine = engine
self.keywords_list = literal["keywords"]
self.functions_list = literal["functions"]
self.syntax_style = config["main"]["syntax_style"]
Expand Down Expand Up @@ -125,7 +128,7 @@ def run_cli(self):
break # Control-D pressed.

try:
output = self.opensearch_executor.execute_query(text)
output = self.opensearch_executor.execute_query(text, engine=self.engine)
if output:
formatter = Formatter(settings)
formatted_output = formatter.format_output(output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class ConnectionImpl implements OpenSearchConnection, JdbcWrapper, Loggin
private String user;
private Logger log;
private int fetchSize;
private String engine;
private boolean open = false;
private Transport transport;
private Protocol protocol;
Expand All @@ -66,6 +67,7 @@ public ConnectionImpl(ConnectionConfig connectionConfig, TransportFactory transp
this.url = connectionConfig.getUrl();
this.user = connectionConfig.getUser();
this.fetchSize = connectionConfig.getFetchSize();
this.engine = connectionConfig.getEngine();

try {
this.transport = transportFactory.getTransport(connectionConfig, log, getUserAgent());
Expand Down Expand Up @@ -97,6 +99,10 @@ public int getFetchSize() {
return fetchSize;
}

public String getEngine() {
return engine;
}

@Override
public Statement createStatement() throws SQLException {
log.debug(() -> logEntry("createStatement()"));
Expand Down
Loading