Skip to content

Commit

Permalink
feat: add knnQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
KennyLindahl authored and Andreas Franzon committed Mar 8, 2024
1 parent 09f73cc commit 06a6c71
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/core/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ exports.Aggregation = require('./aggregation');

exports.Query = require('./query');

exports.Knn = require('./knn');

exports.Suggester = require('./suggester');

exports.Script = require('./script');
Expand Down
131 changes: 131 additions & 0 deletions src/core/knn.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
'use strict';

const { recursiveToJSON, checkType } = require('./util');
const Query = require('./query');

/**
* Class representing a k-Nearest Neighbors (k-NN) query.
* This class extends the Query class to support the specifics of k-NN search, including setting up the field,
* query vector, number of neighbors (k), and number of candidates.
*
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html)
*/
class Knn {
/**

Check warning on line 14 in src/core/knn.js

View workflow job for this annotation

GitHub Actions / check (10.x)

Missing JSDoc for parameter 'field'

Check warning on line 14 in src/core/knn.js

View workflow job for this annotation

GitHub Actions / check (12.x)

Missing JSDoc for parameter 'field'

Check warning on line 14 in src/core/knn.js

View workflow job for this annotation

GitHub Actions / check (14.x)

Missing JSDoc for parameter 'field'
* Creates an instance of Knn.
*/
constructor(field, k, numCandidates) {
if (k > numCandidates)
throw new Error('Knn numCandidates cannot be less than k');
this._body = {};
this._body.field = field;
this._body.k = k;
this._body.filter = [];
this._body.num_candidates = numCandidates;
}

/**
* Sets the query vector for the k-NN search.
* @param {Array<number>} vector - The query vector.
* @returns {Knn} Returns the instance of Knn for method chaining.
*/
queryVector(vector) {
if (this._body.query_vector_builder)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this._body.query_vector = vector;
return this;
}

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* It's important to note that either a direct query vector or a query vector builder can be
* provided, but not both.
*
* @param {string} modelId - The ID of the model to be used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @throws {Error} Throws an error if both query_vector_builder and query_vector are provided.
* @returns {Knn} Returns the instance of Knn for method chaining.
*
* Usage example:
* let knn = new Knn();
* knn.queryVectorBuilder('model_123', 'Sample model text');
*/
queryVectorBuilder(modelId, modelText) {
if (this._body.query_vector)
throw new Error(
'cannot provide both query_vector_builder and query_vector'
);
this.query_vector_builder = {
text_embeddings: {
model_id: modelId,
model_text: modelText
}
};
return this;
}

/**
* Adds one or more filter queries to the k-NN search.
*
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
* If any provided query is not an instance of `Query`, a TypeError is thrown.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {Knn} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*
* Usage example:
* let knn = new Knn();
* knn.filter(new TermQuery('field', 'value')); // Applying a single filter query
* knn.filter([new TermQuery('field1', 'value1'), new TermQuery('field2', 'value2')]); // Applying multiple filter queries
*/
filter(queries) {
const queryArray = Array.isArray(queries) ? queries : [queries];
queryArray.forEach(query => {
checkType(query, Query);
this._body.filter.push(query);
});
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} boost - The number of the boost
* @returns {Knn} Returns the instance of Knn for method chaining.
*/
boost(boost) {
this._body.boost = boost;
return this;
}

/**
* Sets the field to perform the k-NN search on.
* @param {number} similarity - The number of the similarity
* @returns {Knn} Returns the instance of Knn for method chaining.
*/
similarity(similarity) {
this._body.similarity = similarity;
return this;
}

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
* @returns {Object} returns an Object which maps to the elasticsearch query DSL
*/
toJSON() {
if (!this._body.query_vector && !this._body.query_vector_builder)
throw new Error(
'either query_vector_builder or query_vector must be provided'
);
return recursiveToJSON(this._body);
}
}

module.exports = Knn;
21 changes: 20 additions & 1 deletion src/core/request-body-search.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ const Query = require('./query'),
Rescore = require('./rescore'),
Sort = require('./sort'),
Highlight = require('./highlight'),
InnerHits = require('./inner-hits');
InnerHits = require('./inner-hits'),
Knn = require('./knn');

const { checkType, setDefault, recursiveToJSON } = require('./util');

Expand Down Expand Up @@ -69,6 +70,7 @@ class RequestBodySearch {
constructor() {
// Maybe accept some optional parameter?
this._body = {};
this._knn = [];
this._aggs = [];
this._suggests = [];
this._suggestText = null;
Expand All @@ -87,6 +89,21 @@ class RequestBodySearch {
return this;
}

/**
* Sets knn on the search request body.
*
* @param {knn} knn
* @returns {RequestBodySearch} returns `this` so that calls can be chained.
*/
knn(knn) {
const knns = Array.isArray(knn) ? knn : [knn];
knns.forEach(_knn => {
checkType(_knn, Knn);
this._knn.push(_knn);
});
return this;
}

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -785,6 +802,8 @@ class RequestBodySearch {
toJSON() {
const dsl = recursiveToJSON(this._body);

if (!isEmpty(this._knn)) dsl.knn = this._knn;

if (!isEmpty(this._aggs)) dsl.aggs = recMerge(this._aggs);

if (!isEmpty(this._suggests) || !isNil(this._suggestText)) {
Expand Down
87 changes: 86 additions & 1 deletion src/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ declare namespace esb {
*/
query(query: Query): this;

/**
* Sets knn on the request body.
*
* @param {Knn} knn
*/
knn(knn: Knn | Knn[]): this;

/**
* Sets aggregation on the request body.
* Alias for method `aggregation`
Expand Down Expand Up @@ -3074,7 +3081,7 @@ declare namespace esb {

/**
* Sets the script used to compute the score of documents returned by the query.
*
*
* @param {Script} script A valid `Script` object
*/
script(script: Script): this;
Expand Down Expand Up @@ -3614,6 +3621,84 @@ declare namespace esb {
spanQry?: SpanQueryBase
): SpanFieldMaskingQuery;

/**
* Knn performs k-nearest neighbor (KNN) searches.
* This class allows configuring the KNN search with various parameters such as field, query vector,
* number of nearest neighbors (k), number of candidates, boost factor, and similarity metric.
*
* NOTE: Only available in Elasticsearch v8.0+
*/
export class Knn {
/**
* Creates an instance of Knn, initializing the internal state for the k-NN search.
*
* @param {string} field - (Optional) The field against which to perform the k-NN search.
* @param {number} k - (Optional) The number of nearest neighbors to retrieve.
* @param {number} numCandidates - (Optional) The number of candidate neighbors to consider during the search.
* @throws {Error} If the number of candidates (numCandidates) is less than the number of neighbors (k).
*/
constructor(field: string, k: number, numCandidates: number);

/**
* Sets the query vector for the KNN search, an array of numbers representing the reference point.
*
* @param {number[]} vector
*/
queryVector(vector: number[]): this;

/**
* Sets the query vector builder for the k-NN search.
* This method configures a query vector builder using a specified model ID and model text.
* Note that either a direct query vector or a query vector builder can be provided, but not both.
*
* @param {string} modelId - The ID of the model used for generating the query vector.
* @param {string} modelText - The text input based on which the query vector is generated.
* @throws {Error} If both query_vector_builder and query_vector are provided.
* @returns {Knn} Returns the instance of Knn for method chaining.
*/
queryVectorBuilder(modelId: string, modelText: string): this;

/**
* Adds one or more filter queries to the k-NN search.
* This method is designed to apply filters to the k-NN search. It accepts either a single
* query or an array of queries. Each query acts as a filter, refining the search results
* according to the specified conditions. These queries must be instances of the `Query` class.
*
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering.
* @returns {Knn} Returns `this` to allow method chaining.
* @throws {TypeError} If any of the provided queries is not an instance of `Query`.
*/
filter(queries: Query | Query[]): this;

/**
* Applies a boost factor to the query to influence the relevance score of returned documents.
*
* @param {number} boost
*/
boost(boost: number): this;

/**
* Sets the similarity metric used in the KNN algorithm to calculate similarity.
*
* @param {number} similarity
*/
similarity(similarity: number): this;

/**
* Override default `toJSON` to return DSL representation for the `query`
*
* @override
*/
toJSON(): object;
}

/**
* Factory function to instantiate a new Knn object.
*
* @returns {Knn}
*/
export function knn(field: string, k: number, numCandidates: number): Knn;

/**
* Base class implementation for all aggregation types.
*
Expand Down
8 changes: 8 additions & 0 deletions src/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const {
InnerHits,
SearchTemplate,
Query,
Knn,
util: { constructorWrapper }
} = require('./core');

Expand Down Expand Up @@ -333,6 +334,13 @@ exports.spanWithinQuery = constructorWrapper(SpanWithinQuery);

exports.SpanFieldMaskingQuery = SpanFieldMaskingQuery;
exports.spanFieldMaskingQuery = constructorWrapper(SpanFieldMaskingQuery);

/* ============ ============ ============ */
/* ======== Knn ======== */
/* ============ ============ ============ */
exports.Knn = Knn;
exports.knn = constructorWrapper(Knn);

/* ============ ============ ============ */
/* ======== Metrics Aggregations ======== */
/* ============ ============ ============ */
Expand Down
75 changes: 75 additions & 0 deletions test/core-test/knn.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import test from 'ava';
import { Knn, TermQuery } from '../../src';

test('knn can be instantiated', t => {
const knn = new Knn('my_field', 5, 10);
t.truthy(knn);
});

test('knn throws error if numCandidates is less than k', t => {
const error = t.throws(() => new Knn('my_field', 10, 5));
t.is(error.message, 'Knn numCandidates cannot be less than k');
});

test('knn queryVector sets correctly', t => {
const vector = [1, 2, 3];
const knn = new Knn('my_field', 5, 10).queryVector(vector);
t.deepEqual(knn._body.query_vector, vector);
});

test('knn queryVectorBuilder sets correctly', t => {
const modelId = 'model_123';
const modelText = 'Sample model text';
const knn = new Knn('my_field', 5, 10).queryVectorBuilder(
modelId,
modelText
);
t.deepEqual(knn.query_vector_builder.text_embeddings, {
model_id: modelId,
model_text: modelText
});
});

test('knn filter method adds queries correctly', t => {
const knn = new Knn('my_field', 5, 10);
const query = new TermQuery('field', 'value');
knn.filter(query);
t.deepEqual(knn._body.filter, [query]);
});

test('knn boost method sets correctly', t => {
const boostValue = 1.5;
const knn = new Knn('my_field', 5, 10).boost(boostValue);
t.is(knn._body.boost, boostValue);
});

test('knn similarity method sets correctly', t => {
const similarityValue = 0.8;
const knn = new Knn('my_field', 5, 10).similarity(similarityValue);
t.is(knn._body.similarity, similarityValue);
});

test('knn toJSON method returns correct DSL', t => {
const knn = new Knn('my_field', 5, 10)
.queryVector([1, 2, 3])
.filter(new TermQuery('field', 'value'));

const expectedDSL = {
field: 'my_field',
k: 5,
num_candidates: 10,
query_vector: [1, 2, 3],
filter: [{ term: { field: 'value' } }]
};

t.deepEqual(knn.toJSON(), expectedDSL);
});

test('knn toJSON throws error if neither query_vector nor query_vector_builder is provided', t => {
const knn = new Knn('my_field', 5, 10);
const error = t.throws(() => knn.toJSON());
t.is(
error.message,
'either query_vector_builder or query_vector must be provided'
);
});

0 comments on commit 06a6c71

Please sign in to comment.