Skip to content

Commit

Permalink
Fix refactoring, spotless
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosdelest committed Nov 14, 2024
1 parent 4497b92 commit 955da1f
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ public final class KnnRetrieverBuilder extends RetrieverBuilder {
(QueryVectorBuilder) args[2],
(int) args[3],
(int) args[4],
(RescoreVectorBuilder) args[6], (Float) args[5]
(RescoreVectorBuilder) args[6],
(Float) args[5]
);
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public static KnnRetrieverBuilder createRandomKnnRetrieverBuilder() {
null,
k,
numCands,
rescoreVectorBuilder, similarity
rescoreVectorBuilder,
similarity
);

List<QueryBuilder> preFilterQueryBuilders = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ private List<RetrieverBuilder> innerRetrievers(QueryRewriteContext queryRewriteC
null,
randomInt(10),
randomIntBetween(10, 100),
randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)), randomFloat()
randomBoolean() ? null : new RescoreVectorBuilder(randomFloatBetween(1.0f, 10.0f, false)),
randomFloat()
);
if (randomBoolean()) {
knnRetrieverBuilder.preFilterQueryBuilders = preFilters(queryRewriteContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ public void testTelemetryForRRFRetriever() throws IOException {

// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null)));
performSearch(
new SearchSourceBuilder().retriever(
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null)
)
);
}

// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
Expand Down Expand Up @@ -146,7 +150,9 @@ public void testTelemetryForRRFRetriever() throws IOException {

// search#6 - this will record 1 entry for "knn" in `sections`
{
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
performSearch(
new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)))
);
}

// search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,15 @@ public void testRRFPagination() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(
VECTOR_FIELD,
new float[] { 2.0f },
null,
10,
100,
null,
null
);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -233,7 +241,7 @@ public void testRRFWithAggs() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -288,7 +296,7 @@ public void testRRFWithCollapse() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -345,7 +353,7 @@ public void testRRFRetrieverWithCollapseAndAggs() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -411,7 +419,7 @@ public void testMultipleRRFRetrievers() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand All @@ -430,7 +438,7 @@ public void testMultipleRRFRetrievers() {
),
// this one bring just doc 7 which should be ranked first eventually
new CompoundRetrieverBuilder.RetrieverSource(
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null),
new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 7.0f }, null, 1, 100, null, null),
null
)
),
Expand Down Expand Up @@ -477,7 +485,7 @@ public void testRRFExplainWithNamedRetrievers() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -536,7 +544,7 @@ public void testRRFExplainWithAnotherNestedRRF() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 2, 3, 6, and 7
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 2.0f }, null, 10, 100, null, null);

RRFRetrieverBuilder nestedRRF = new RRFRetrieverBuilder(
Arrays.asList(
Expand Down Expand Up @@ -773,7 +781,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
throw new IllegalStateException("Should not be called");
}
};
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, (RescoreVectorBuilder) args[6], null);
var knn = new KnnRetrieverBuilder("vector", null, vectorBuilder, 10, 10, null, null);
var standard = new StandardRetrieverBuilder(new KnnVectorQueryBuilder("vector", vectorBuilder, 10, 10, null));
var rrf = new RRFRetrieverBuilder(
List.of(new CompoundRetrieverBuilder.RetrieverSource(knn, null), new CompoundRetrieverBuilder.RetrieverSource(standard, null)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public void testRRFRetrieverWithNestedQuery() {
);
standard1.getPreFilterQueryBuilders().add(QueryBuilders.queryStringQuery("search").defaultField(TEXT_FIELD));
// this one retrieves docs 6
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, (RescoreVectorBuilder) args[6], null);
KnnRetrieverBuilder knnRetrieverBuilder = new KnnRetrieverBuilder(VECTOR_FIELD, new float[] { 6.0f }, null, 1, 100, null, null);
source.retriever(
new RRFRetrieverBuilder(
Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ public void testTelemetryForRRFRetriever() throws IOException {

// search#1 - this will record 1 entry for "retriever" in `sections`, and 1 for "knn" under `retrievers`
{
performSearch(new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null)));
performSearch(
new SearchSourceBuilder().retriever(new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null))
);
}

// search#2 - this will record 1 entry for "retriever" in `sections`, 1 for "standard" under `retrievers`, and 1 for "range" under
Expand Down Expand Up @@ -136,7 +138,7 @@ public void testTelemetryForRRFRetriever() throws IOException {
new RRFRetrieverBuilder(
Arrays.asList(
new CompoundRetrieverBuilder.RetrieverSource(
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, (RescoreVectorBuilder) args[6], null),
new KnnRetrieverBuilder("vector", new float[] { 1.0f }, null, 10, 15, null, null),
null
),
new CompoundRetrieverBuilder.RetrieverSource(
Expand All @@ -153,7 +155,9 @@ public void testTelemetryForRRFRetriever() throws IOException {

// search#6 - this will record 1 entry for "knn" in `sections`
{
performSearch(new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null))));
performSearch(
new SearchSourceBuilder().knnSearch(List.of(new KnnSearchBuilder("vector", new float[] { 1.0f }, 10, 15, null, null)))
);
}

// search#7 - this will record 1 entry for "query" in `sections`, and 1 for "match_all" under `queries`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public RetrieverBuilder toRetriever(SearchSourceBuilder source, Predicate<NodeFe
knnSearchBuilder.getNumCands(),
knnSearchBuilder.getRescoreVectorBuilder(),
knnSearchBuilder.getSimilarity()
);
);
knnRetriever.retrieverName(knnSearchBuilder.queryName());
retrieverSources.add(new CompoundRetrieverBuilder.RetrieverSource(knnRetriever, null));
}
Expand Down

0 comments on commit 955da1f

Please sign in to comment.