Skip to content

Commit

Permalink
Semantic search with query builder rewrite (elastic#118676) (elastic#…
Browse files Browse the repository at this point in the history
…118945)

* Semantic search with query builder rewrite

* Address review feedback

* Add feature behind snapshot

* Use after/before instead of afterClass/beforeClass

* Call onFailure instead of throwing exception

* Fix KqlFunctionIT by requiring KqlPlugin

* Update scoring tests now that they are enabled

* Drop the score column for now
  • Loading branch information
ioanatia authored Dec 18, 2024
1 parent f60d0f8 commit 4a99215
Show file tree
Hide file tree
Showing 18 changed files with 629 additions and 7 deletions.
20 changes: 18 additions & 2 deletions server/src/main/java/org/elasticsearch/action/ResolvedIndices.java
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,26 @@ public static ResolvedIndices resolveWithIndicesRequest(
RemoteClusterService remoteClusterService,
long startTimeInMillis
) {
final Map<String, OriginalIndices> remoteClusterIndices = remoteClusterService.groupIndices(
return resolveWithIndexNamesAndOptions(
request.indices(),
request.indicesOptions(),
request.indices()
clusterState,
indexNameExpressionResolver,
remoteClusterService,
startTimeInMillis
);
}

public static ResolvedIndices resolveWithIndexNamesAndOptions(
String[] indexNames,
IndicesOptions indicesOptions,
ClusterState clusterState,
IndexNameExpressionResolver indexNameExpressionResolver,
RemoteClusterService remoteClusterService,
long startTimeInMillis
) {
final Map<String, OriginalIndices> remoteClusterIndices = remoteClusterService.groupIndices(indicesOptions, indexNames);

final OriginalIndices localIndices = remoteClusterIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY);

Index[] concreteLocalIndices = localIndices == null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.qa.multi_node;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;

import org.elasticsearch.test.TestClustersThreadFilter;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.xpack.esql.qa.rest.SemanticMatchTestCase;
import org.junit.ClassRule;

/**
 * Runs the {@link SemanticMatchTestCase} checks against the cluster built by
 * {@code Clusters.testCluster} for this package, with the {@code inference-service-test}
 * plugin added to the cluster spec.
 */
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class SemanticMatchIT extends SemanticMatchTestCase {

    /** Class-level cluster rule; the REST client addresses are taken from it. */
    @ClassRule
    public static ElasticsearchCluster cluster = Clusters.testCluster(
        clusterSpec -> clusterSpec.plugin("inference-service-test")
    );

    /**
     * @return the HTTP addresses reported by {@link #cluster}
     */
    @Override
    protected String getTestRestCluster() {
        return cluster.getHttpAddresses();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.qa.single_node;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;

import org.elasticsearch.test.TestClustersThreadFilter;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.xpack.esql.qa.rest.SemanticMatchTestCase;
import org.junit.ClassRule;

/**
 * Runs the {@link SemanticMatchTestCase} checks against the cluster built by
 * {@code Clusters.testCluster} for this package, with the {@code inference-service-test}
 * plugin added to the cluster spec.
 */
@ThreadLeakFilters(filters = TestClustersThreadFilter.class)
public class SemanticMatchIT extends SemanticMatchTestCase {

    /** Class-level cluster rule; the REST client addresses are taken from it. */
    @ClassRule
    public static ElasticsearchCluster cluster = Clusters.testCluster(
        clusterSpec -> clusterSpec.plugin("inference-service-test")
    );

    /**
     * @return the HTTP addresses reported by {@link #cluster}
     */
    @Override
    protected String getTestRestCluster() {
        return cluster.getHttpAddresses();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql.qa.rest;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xpack.esql.action.EsqlCapabilities;
import org.junit.After;
import org.junit.Before;

import java.io.IOException;
import java.util.Map;

import static org.hamcrest.core.StringContains.containsString;

/**
 * REST-level checks for the error responses returned when an ES|QL {@code match} is applied to
 * {@code semantic_text} fields whose inference configuration cannot be used: either the same
 * field maps to different inference IDs across the queried indices, or the configured
 * inference endpoint does not exist.
 *
 * <p>Every lifecycle hook is skipped when {@code EsqlCapabilities.Cap.SEMANTIC_TEXT_TYPE} is
 * not enabled.
 */
public abstract class SemanticMatchTestCase extends ESRestTestCase {

    /** Skip message used by all lifecycle hooks when the capability is switched off. */
    private static final String CAPABILITY_UNAVAILABLE = "semantic text capability not available";

    /**
     * Creates three indices that each declare {@code semantic_text_field} with a different
     * {@code inference_id}.
     */
    @Before
    public void setUpIndices() throws IOException {
        assumeSemanticTextEnabled();

        final Settings noSettings = Settings.builder().build();

        final String sparseMapping = """
            "properties": {
            "semantic_text_field": {
            "type": "semantic_text",
            "inference_id": "test_sparse_inference"
            }
            }
            """;
        createIndex(adminClient(), "test-semantic1", noSettings, sparseMapping);

        final String denseMapping = """
            "properties": {
            "semantic_text_field": {
            "type": "semantic_text",
            "inference_id": "test_dense_inference"
            }
            }
            """;
        createIndex(adminClient(), "test-semantic2", noSettings, denseMapping);

        // This index points at an inference ID that no hook in this class registers.
        final String danglingMapping = """
            "properties": {
            "semantic_text_field": {
            "type": "semantic_text",
            "inference_id": "inexistent"
            }
            }
            """;
        createIndex(adminClient(), "test-semantic3", noSettings, danglingMapping);
    }

    /**
     * Registers the {@code test_dense_inference} text-embedding endpoint through the
     * {@code _inference} API.
     */
    @Before
    public void setUpTextEmbeddingInferenceEndpoint() throws IOException {
        assumeSemanticTextEnabled();

        final Request createEndpoint = new Request("PUT", "_inference/text_embedding/test_dense_inference");
        createEndpoint.setJsonEntity("""
            {
            "service": "test_service",
            "service_settings": {
            "model": "my_model",
            "api_key": "abc64"
            },
            "task_settings": {
            }
            }
            """);
        adminClient().performRequest(createEndpoint);
    }

    /**
     * Deletes all indices and then the {@code test_dense_inference} endpoint. A 404 on the
     * endpoint deletion is tolerated; any other failure is rethrown.
     */
    @After
    public void wipeData() throws IOException {
        assumeSemanticTextEnabled();

        adminClient().performRequest(new Request("DELETE", "*"));

        try {
            adminClient().performRequest(new Request("DELETE", "_inference/test_dense_inference"));
        } catch (ResponseException e) {
            // 404 here means the endpoint was not created
            final boolean endpointMissing = statusCodeOf(e) == 404;
            if (endpointMissing == false) {
                throw e;
            }
        }
    }

    /**
     * Querying two indices whose {@code semantic_text_field} uses different inference IDs must
     * be rejected with a 400 that names the field.
     */
    public void testWithMultipleInferenceIds() throws IOException {
        final String esql = """
            from test-semantic1,test-semantic2
            | where match(semantic_text_field, "something")
            """;

        final ResponseException failure = expectThrows(ResponseException.class, () -> runEsqlQuery(esql));

        assertThat(
            failure.getMessage(),
            containsString("Field [semantic_text_field] has multiple inference IDs associated with it")
        );
        assertEquals(400, statusCodeOf(failure));
    }

    /**
     * Querying an index whose {@code semantic_text_field} references a missing inference
     * endpoint must be rejected with a 404.
     */
    public void testWithInferenceNotConfigured() {
        final String esql = """
            from test-semantic3
            | where match(semantic_text_field, "something")
            """;

        final ResponseException failure = expectThrows(ResponseException.class, () -> runEsqlQuery(esql));

        assertThat(failure.getMessage(), containsString("Inference endpoint not found"));
        assertEquals(404, statusCodeOf(failure));
    }

    /** Skips the calling hook unless the semantic text capability is enabled. */
    private static void assumeSemanticTextEnabled() {
        assumeTrue(CAPABILITY_UNAVAILABLE, EsqlCapabilities.Cap.SEMANTIC_TEXT_TYPE.isEnabled());
    }

    /** Extracts the HTTP status code carried by a failed REST response. */
    private static int statusCodeOf(ResponseException e) {
        return e.getResponse().getStatusLine().getStatusCode();
    }

    /**
     * Submits {@code query} through {@code RestEsqlTestCase.runEsqlSync} and returns its result.
     */
    private Map<String, Object> runEsqlQuery(String query) throws IOException {
        return RestEsqlTestCase.runEsqlSync(RestEsqlTestCase.requestObjectBuilder().query(query));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.elasticsearch.xpack.esql.plugin.EsqlPlugin;
import org.elasticsearch.xpack.esql.plugin.QueryPragmas;
import org.elasticsearch.xpack.esql.session.Configuration;
import org.elasticsearch.xpack.esql.session.QueryBuilderResolver;
import org.elasticsearch.xpack.esql.stats.Metrics;
import org.elasticsearch.xpack.esql.stats.SearchStats;
import org.elasticsearch.xpack.versionfield.Version;
Expand Down Expand Up @@ -351,6 +352,8 @@ public String toString() {

public static final Verifier TEST_VERIFIER = new Verifier(new Metrics(new EsqlFunctionRegistry()), new XPackLicenseState(() -> 0L));

public static final QueryBuilderResolver MOCK_QUERY_BUILDER_RESOLVER = new MockQueryBuilderResolver();

private EsqlTestUtils() {}

public static Configuration configuration(QueryPragmas pragmas, String query) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.esql;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.esql.session.QueryBuilderResolver;
import org.elasticsearch.xpack.esql.session.Result;

import java.util.function.BiConsumer;

/**
 * Test stand-in for {@link QueryBuilderResolver} that performs no resolution: the plan it is
 * given is handed straight to the callback.
 */
public class MockQueryBuilderResolver extends QueryBuilderResolver {

    /** Passes {@code null} for every superclass constructor argument. */
    public MockQueryBuilderResolver() {
        super(null, null, null, null);
    }

    /**
     * Invokes {@code callback} immediately with the unmodified {@code plan} and {@code listener}.
     *
     * @param plan     the logical plan, forwarded as-is
     * @param listener the result listener, forwarded as-is
     * @param callback the continuation to run
     */
    @Override
    public void resolveQueryBuilders(
        LogicalPlan plan,
        ActionListener<Result> listener,
        BiConsumer<LogicalPlan, ActionListener<Result>> callback
    ) {
        callback.accept(plan, listener);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -597,3 +597,74 @@ from employees,employees_incompatible

emp_no_bool:boolean
;

// match() applied to a semantic_text field; the output is sorted ascending on that field, which fixes the row order compared below.
testMatchWithSemanticText
required_capability: match_function
required_capability: semantic_text_type

from semantic_text
| where match(semantic_text_field, "something")
| keep semantic_text_field
| sort semantic_text_field asc
;

semantic_text_field:semantic_text
all we have to decide is what to do with the time that is given to us
be excellent to each other
live long and prosper
;

// Combines match() on a semantic_text field with match() on the keyword field host in a single WHERE clause.
testMatchWithSemanticTextAndKeyword
required_capability: match_function
required_capability: semantic_text_type

from semantic_text
| where match(semantic_text_field, "something") AND match(host, "host1")
| keep semantic_text_field, host
;

semantic_text_field:semantic_text | host:keyword
live long and prosper | host1
;

// match() on st_multi_value, which holds several values in the expected row; _id is requested via METADATA to identify the document.
testMatchWithSemanticTextMultiValueField
required_capability: match_function
required_capability: semantic_text_type

from semantic_text metadata _id
| where match(st_multi_value, "something") AND match(host, "host1")
| keep _id, st_multi_value
;

_id: keyword | st_multi_value:semantic_text
1 | ["Hello there!", "This is a random value", "for testing purposes"]
;

// Chains qstr(), an EVAL using mv_count(), match() on both a semantic_text and a keyword field, and a final STATS count(*).
testMatchWithSemanticTextWithEvalsAndOtherFunctionsAndStats
required_capability: match_function
required_capability: semantic_text_type

from semantic_text
| where qstr("description:some*")
| eval size = mv_count(st_multi_value)
| where match(semantic_text_field, "something") AND size > 1 AND match(host, "host1")
| STATS result = count(*)
;

result:long
1
;

// Mixes kql() with match() on a semantic_text field; additionally requires the kql_function capability.
testMatchWithSemanticTextAndKql
required_capability: match_function
required_capability: semantic_text_type
required_capability: kql_function

from semantic_text
| where kql("host:host1") AND match(semantic_text_field, "something")
| KEEP host, semantic_text_field
;

host:keyword | semantic_text_field:semantic_text
"host1" | live long and prosper
;
Loading

0 comments on commit 4a99215

Please sign in to comment.