Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
  • Loading branch information
pmpailis committed Oct 4, 2024
1 parent f97e5c1 commit 88b9384
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 102 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.search.retriever;

import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesRequest;
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.KnnVectorQueryBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.junit.Before;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
import static org.hamcrest.Matchers.equalTo;

@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0)
public class RetrieverTelemetryIT extends ESIntegTestCase {

private static final String INDEX_NAME = "test_index";

@Override
protected boolean addMockHttpTransport() {
return false; // enable http
}

@Before
public void setup() throws IOException {
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject("vector")
.field("type", "dense_vector")
.field("dims", 1)
.field("index", true)
.field("similarity", "l2_norm")
.startObject("index_options")
.field("type", "hnsw")
.endObject()
.endObject()
.startObject("text")
.field("type", "text")
.endObject()
.startObject("integer")
.field("type", "integer")
.endObject()
.startObject("topic")
.field("type", "keyword")
.endObject()
.endObject()
.endObject();

assertAcked(prepareCreate(INDEX_NAME).setMapping(builder));
ensureGreen(INDEX_NAME);
}

public void testTelemetryForRRFRetriever() throws IOException {

if (false == isRetrieverTelemetryEnabled()) {
return;
}

// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new KnnRetrieverBuilder("vector", new float[]{1.0f}, null, 10, 15, null));
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
// `queries`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new StandardRetrieverBuilder(QueryBuilders.rangeQuery("integer").gte(2)));
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// search#3 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "knn" under
// `queries`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", new float[]{1.0f}, 10, 15, null)));
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// search#4 - this will record 1 entry for "retriever" in `sections`, and 1 for "standard" under `retrievers`, and 1 for "term"
// under `queries`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new StandardRetrieverBuilder(QueryBuilders.termQuery("topic", "foo")));
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// search#5 - this will record 1 entry for "knn" in `sections`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.knnSearch(List.of(new KnnSearchBuilder("vector", new float[]{1.0f}, 10, 15, null)));
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// search#6 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
{
Request request = new Request("GET", INDEX_NAME + "/_search");
SearchSourceBuilder source = new SearchSourceBuilder();
source.query(QueryBuilders.matchAllQuery());
request.setJsonEntity(Strings.toString(source));
getRestClient().performRequest(request);
}

// cluster stats
{
SearchUsageStats stats = clusterAdmin().prepareClusterStats().get().getIndicesStats().getSearchUsageStats();
assertEquals(6, stats.getTotalSearchCount());

assertThat(stats.getSectionsUsage().size(), equalTo(3));
assertThat(stats.getSectionsUsage().get("retriever"), equalTo(4L));
assertThat(stats.getSectionsUsage().get("query"), equalTo(1L));
assertThat(stats.getSectionsUsage().get("knn"), equalTo(1L));

assertThat(stats.getRetrieversUsage().size(), equalTo(3));
assertThat(stats.getRetrieversUsage().get("standard"), equalTo(3L));
assertThat(stats.getRetrieversUsage().get("knn"), equalTo(1L));

assertThat(stats.getQueryUsage().size(), equalTo(5));
assertThat(stats.getQueryUsage().get("range"), equalTo(1L));
assertThat(stats.getQueryUsage().get("term"), equalTo(1L));
assertThat(stats.getQueryUsage().get("match_all"), equalTo(1L));
assertThat(stats.getQueryUsage().get("knn"), equalTo(1L));
}
}

private boolean isRetrieverTelemetryEnabled() throws IOException {
NodesCapabilitiesResponse res = clusterAdmin().nodesCapabilities(
new NodesCapabilitiesRequest().method(RestRequest.Method.GET).path("_cluster/stats").capabilities("retrievers-usage-stats")
).actionGet();
return res != null && res.isSupported().orElse(false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.elasticsearch.action.admin.cluster.node.capabilities.NodesCapabilitiesResponse;
import org.elasticsearch.action.admin.cluster.stats.SearchUsageStats;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
Expand Down

0 comments on commit 88b9384

Please sign in to comment.