-
Notifications
You must be signed in to change notification settings - Fork 77
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
09f73cc
commit 06a6c71
Showing
6 changed files
with
322 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
'use strict'; | ||
|
||
const { recursiveToJSON, checkType } = require('./util'); | ||
const Query = require('./query'); | ||
|
||
/** | ||
* Class representing a k-Nearest Neighbors (k-NN) query. | ||
* This class extends the Query class to support the specifics of k-NN search, including setting up the field, | ||
* query vector, number of neighbors (k), and number of candidates. | ||
* | ||
* [Elasticsearch reference](https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html) | ||
*/ | ||
class Knn { | ||
/** | ||
Check warning on line 14 in src/core/knn.js GitHub Actions / check (10.x)
Check warning on line 14 in src/core/knn.js GitHub Actions / check (12.x)
|
||
* Creates an instance of Knn. | ||
*/ | ||
constructor(field, k, numCandidates) { | ||
if (k > numCandidates) | ||
throw new Error('Knn numCandidates cannot be less than k'); | ||
this._body = {}; | ||
this._body.field = field; | ||
this._body.k = k; | ||
this._body.filter = []; | ||
this._body.num_candidates = numCandidates; | ||
} | ||
|
||
/** | ||
* Sets the query vector for the k-NN search. | ||
* @param {Array<number>} vector - The query vector. | ||
* @returns {Knn} Returns the instance of Knn for method chaining. | ||
*/ | ||
queryVector(vector) { | ||
if (this._body.query_vector_builder) | ||
throw new Error( | ||
'cannot provide both query_vector_builder and query_vector' | ||
); | ||
this._body.query_vector = vector; | ||
return this; | ||
} | ||
|
||
/** | ||
* Sets the query vector builder for the k-NN search. | ||
* This method configures a query vector builder using a specified model ID and model text. | ||
* It's important to note that either a direct query vector or a query vector builder can be | ||
* provided, but not both. | ||
* | ||
* @param {string} modelId - The ID of the model to be used for generating the query vector. | ||
* @param {string} modelText - The text input based on which the query vector is generated. | ||
* @throws {Error} Throws an error if both query_vector_builder and query_vector are provided. | ||
* @returns {Knn} Returns the instance of Knn for method chaining. | ||
* | ||
* Usage example: | ||
* let knn = new Knn(); | ||
* knn.queryVectorBuilder('model_123', 'Sample model text'); | ||
*/ | ||
queryVectorBuilder(modelId, modelText) { | ||
if (this._body.query_vector) | ||
throw new Error( | ||
'cannot provide both query_vector_builder and query_vector' | ||
); | ||
this.query_vector_builder = { | ||
text_embeddings: { | ||
model_id: modelId, | ||
model_text: modelText | ||
} | ||
}; | ||
return this; | ||
} | ||
|
||
/** | ||
* Adds one or more filter queries to the k-NN search. | ||
* | ||
* This method is designed to apply filters to the k-NN search. It accepts either a single | ||
* query or an array of queries. Each query acts as a filter, refining the search results | ||
* according to the specified conditions. These queries must be instances of the `Query` class. | ||
* If any provided query is not an instance of `Query`, a TypeError is thrown. | ||
* | ||
* @param {Query|Query[]} queries - A single `Query` instance or an array of `Query` instances for filtering. | ||
* @returns {Knn} Returns `this` to allow method chaining. | ||
* @throws {TypeError} If any of the provided queries is not an instance of `Query`. | ||
* | ||
* Usage example: | ||
* let knn = new Knn(); | ||
* knn.filter(new TermQuery('field', 'value')); // Applying a single filter query | ||
* knn.filter([new TermQuery('field1', 'value1'), new TermQuery('field2', 'value2')]); // Applying multiple filter queries | ||
*/ | ||
filter(queries) { | ||
const queryArray = Array.isArray(queries) ? queries : [queries]; | ||
queryArray.forEach(query => { | ||
checkType(query, Query); | ||
this._body.filter.push(query); | ||
}); | ||
return this; | ||
} | ||
|
||
/** | ||
* Sets the field to perform the k-NN search on. | ||
* @param {number} boost - The number of the boost | ||
* @returns {Knn} Returns the instance of Knn for method chaining. | ||
*/ | ||
boost(boost) { | ||
this._body.boost = boost; | ||
return this; | ||
} | ||
|
||
/** | ||
* Sets the field to perform the k-NN search on. | ||
* @param {number} similarity - The number of the similarity | ||
* @returns {Knn} Returns the instance of Knn for method chaining. | ||
*/ | ||
similarity(similarity) { | ||
this._body.similarity = similarity; | ||
return this; | ||
} | ||
|
||
/** | ||
* Override default `toJSON` to return DSL representation for the `query` | ||
* | ||
* @override | ||
* @returns {Object} returns an Object which maps to the elasticsearch query DSL | ||
*/ | ||
toJSON() { | ||
if (!this._body.query_vector && !this._body.query_vector_builder) | ||
throw new Error( | ||
'either query_vector_builder or query_vector must be provided' | ||
); | ||
return recursiveToJSON(this._body); | ||
} | ||
} | ||
|
||
module.exports = Knn; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import test from 'ava'; | ||
import { Knn, TermQuery } from '../../src'; | ||
|
||
test('knn can be instantiated', t => { | ||
const knn = new Knn('my_field', 5, 10); | ||
t.truthy(knn); | ||
}); | ||
|
||
test('knn throws error if numCandidates is less than k', t => { | ||
const error = t.throws(() => new Knn('my_field', 10, 5)); | ||
t.is(error.message, 'Knn numCandidates cannot be less than k'); | ||
}); | ||
|
||
test('knn queryVector sets correctly', t => { | ||
const vector = [1, 2, 3]; | ||
const knn = new Knn('my_field', 5, 10).queryVector(vector); | ||
t.deepEqual(knn._body.query_vector, vector); | ||
}); | ||
|
||
test('knn queryVectorBuilder sets correctly', t => { | ||
const modelId = 'model_123'; | ||
const modelText = 'Sample model text'; | ||
const knn = new Knn('my_field', 5, 10).queryVectorBuilder( | ||
modelId, | ||
modelText | ||
); | ||
t.deepEqual(knn.query_vector_builder.text_embeddings, { | ||
model_id: modelId, | ||
model_text: modelText | ||
}); | ||
}); | ||
|
||
test('knn filter method adds queries correctly', t => { | ||
const knn = new Knn('my_field', 5, 10); | ||
const query = new TermQuery('field', 'value'); | ||
knn.filter(query); | ||
t.deepEqual(knn._body.filter, [query]); | ||
}); | ||
|
||
test('knn boost method sets correctly', t => { | ||
const boostValue = 1.5; | ||
const knn = new Knn('my_field', 5, 10).boost(boostValue); | ||
t.is(knn._body.boost, boostValue); | ||
}); | ||
|
||
test('knn similarity method sets correctly', t => { | ||
const similarityValue = 0.8; | ||
const knn = new Knn('my_field', 5, 10).similarity(similarityValue); | ||
t.is(knn._body.similarity, similarityValue); | ||
}); | ||
|
||
test('knn toJSON method returns correct DSL', t => { | ||
const knn = new Knn('my_field', 5, 10) | ||
.queryVector([1, 2, 3]) | ||
.filter(new TermQuery('field', 'value')); | ||
|
||
const expectedDSL = { | ||
field: 'my_field', | ||
k: 5, | ||
num_candidates: 10, | ||
query_vector: [1, 2, 3], | ||
filter: [{ term: { field: 'value' } }] | ||
}; | ||
|
||
t.deepEqual(knn.toJSON(), expectedDSL); | ||
}); | ||
|
||
test('knn toJSON throws error if neither query_vector nor query_vector_builder is provided', t => { | ||
const knn = new Knn('my_field', 5, 10); | ||
const error = t.throws(() => knn.toJSON()); | ||
t.is( | ||
error.message, | ||
'either query_vector_builder or query_vector must be provided' | ||
); | ||
}); |