
Commit 89ca31a

Added new RestActions

_aknn_search_vec, _aknn_clear_cache

SthPhoenix authored Feb 22, 2019
1 parent c356cbc commit 89ca31a
Showing 1 changed file with 153 additions and 22 deletions.
@@ -73,11 +73,11 @@ public class AknnRestAction extends BaseRestHandler {
private final Integer K2_DEFAULT = 10;
private final Boolean RESCORE_DEFAULT = true;
private final Integer MINIMUM_DEFAULT = 1;

// TODO: add an option to the index endpoint handler that empties the cache.
private Map<String, LshModel> lshModelCache = new HashMap<>();

@Inject
public AknnRestAction(Settings settings, RestController controller) {
super(settings);
controller.registerHandler(GET, "/{index}/{type}/{id}/" + NAME_SEARCH, this);
@@ -87,17 +87,21 @@ public AknnRestAction(Settings settings, RestController controller) {
controller.registerHandler(GET, NAME_CLEAR_CACHE, this);
}

// @Override
public String getName() {
return NAME;
}

@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException {
if (restRequest.path().endsWith(NAME_SEARCH_VEC))
return handleSearchVecRequest(restRequest, client);
else if (restRequest.path().endsWith(NAME_SEARCH))
return handleSearchRequest(restRequest, client);
else if (restRequest.path().endsWith(NAME_INDEX))
return handleIndexRequest(restRequest, client);
else if (restRequest.path().endsWith(NAME_CLEAR_CACHE))
return handleClearRequest(restRequest, client);
else
return handleCreateRequest(restRequest, client);
}
@@ -108,8 +112,9 @@ public static Double euclideanDistance(List<Double> A, List<Double> B) {
squaredDistance += Math.pow(A.get(i) - B.get(i), 2);
return Math.sqrt(squaredDistance);
}
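// Illustrative check (not part of the commit): for the 3-4-5 right triangle,
// euclideanDistance(Arrays.asList(0.0, 3.0), Arrays.asList(4.0, 0.0)) returns 5.0.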

// Loading the LSH model, refactored into a helper function.
// TODO: fix issues with the stopwatch.
public LshModel InitLsh(String aknnURI, NodeClient client) {
LshModel lshModel;
@@ -140,8 +145,8 @@ public LshModel InitLsh(String aknnURI, NodeClient client) {
}
return lshModel;
}

// Query execution refactored into a function; added a wrapper query.
private List<Map<String, Object>> QueryLsh(List<Double> queryVector, Map<String, Long> queryHashes, String index, String type, Integer k1, Boolean rescore, String filterString, Integer minimum_should_match, Boolean debug, NodeClient client) {
// Retrieve the documents with most matching hashes. https://stackoverflow.com/questions/10773581
StopWatch stopWatch = new StopWatch("StopWatch to query LSH cache");
@@ -168,6 +173,7 @@ private List<Map<String, Object>> QueryLsh(List<Double> queryVector, Map<String,
hashes = null;
}


logger.info("Execute boolean search");
stopWatch.start("Execute boolean search");
SearchResponse approximateSearchResponse = client
@@ -192,8 +198,6 @@ private List<Map<String, Object>> QueryLsh(List<Double> queryVector, Map<String,
hitSource.remove(VECTOR_KEY);
hitSource.remove(HASHES_KEY);
}

// TODO: refactor code below
if (rescore) {
modifiedSortedHits.add(new HashMap<String, Object>() {{
put("_index", hit.getIndex());
@@ -203,6 +207,7 @@ private List<Map<String, Object>> QueryLsh(List<Double> queryVector, Map<String,
put("_source", hitSource);
}});
} else {

modifiedSortedHits.add(new HashMap<String, Object>() {{
put("_index", hit.getIndex());
put("_id", hit.getId());
@@ -227,10 +232,9 @@ private List<Map<String, Object>> QueryLsh(List<Double> queryVector, Map<String,
return modifiedSortedHits;
}

// LSH query part refactored into a function, with more parameters added.

private RestChannelConsumer handleSearchRequest(RestRequest restRequest, NodeClient client) throws IOException {
/**
* Original handleSearchRequest(), refactored for reusability and
* extended with additional parameters, such as a filter query.
*
@@ -246,7 +250,7 @@ private RestChannelConsumer handleSearchRequest(RestRequest restRequest, NodeCli
* @param debug If set to 'True', includes the original vectors and hashes in the hits
* @return The search hits
*/

StopWatch stopWatch = new StopWatch("StopWatch to Time Search Request");

// Parse request parameters.
@@ -304,27 +308,123 @@ private RestChannelConsumer handleSearchRequest(RestRequest restRequest, NodeCli
};
}

private RestChannelConsumer handleSearchVecRequest(RestRequest restRequest, NodeClient client) throws IOException {

/**
* Hybrid of the refactored handleSearchRequest() and handleIndexRequest().
* Takes a document containing a query vector, hashes it, and executes the
* query without indexing.
*
* @param index Index name
* @param type Doc type (keep in mind the forthcoming _type removal in ES7)
* @param _aknn_vector Query vector
* @param filter String in the format of an ES bool query filter (excluding
* the parent 'filter' node)
* @param k1 Number of candidates for scoring
* @param k2 Number of hits returned
* @param minimum_should_match Number of hashes that must match for a hit to be returned
* @param rescore If set to 'True', returns results without the exact-matching stage
* @param debug If set to 'True', includes the original vectors and hashes in the hits
* @param clear_cache Force an update of the LSH model cache before hashing.
* @return The search hits
*/
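// A sketch of the request body this handler parses (not taken from the commit;
// index/model names are hypothetical, VECTOR_KEY is assumed to be "_aknn_vector",
// and the route is assumed to be _aknn_search_vec per the commit message):
//
// POST /_aknn_search_vec
// {
//   "_index": "twitter_images",
//   "_type": "_doc",
//   "_aknn_uri": "aknn_models/aknn_model/twitter_image_search",
//   "query_aknn": {
//     "_aknn_vector": [0.12, -0.45, 0.83],
//     "k1": 100,
//     "k2": 10
//   },
//   "filter": { "term": { "user": "some_user" } }
// }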


StopWatch stopWatch = new StopWatch("StopWatch to Time Search Request");

// Parse request parameters.
stopWatch.start("Parse request parameters");
XContentParser xContentParser = XContentHelper.createParser(
restRequest.getXContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
restRequest.content(),
restRequest.getXContentType());
@SuppressWarnings("unchecked")
Map<String, Object> contentMap = xContentParser.mapOrdered();
@SuppressWarnings("unchecked")
Map<String, Object> aknnQueryMap = (Map<String, Object>) contentMap.get("query_aknn");
@SuppressWarnings("unchecked")
Map<String, ?> filterMap = (Map<String, ?>) contentMap.get("filter");
String filter = null;
if (filterMap != null) {
XContentBuilder filterBuilder = XContentFactory.jsonBuilder()
.map(filterMap);
filter = Strings.toString(filterBuilder);
}

final String index = (String) contentMap.get("_index");
final String type = (String) contentMap.get("_type");
final String aknnURI = (String) contentMap.get("_aknn_uri");
final Integer k1 = (Integer) aknnQueryMap.get("k1");
final Integer k2 = (Integer) aknnQueryMap.get("k2");
final Integer minimum_should_match = restRequest.paramAsInt("minimum_should_match", MINIMUM_DEFAULT);
final Boolean rescore = restRequest.paramAsBoolean("rescore", RESCORE_DEFAULT);
final Boolean clear_cache = restRequest.paramAsBoolean("clear_cache", false);
final Boolean debug = restRequest.paramAsBoolean("debug", false);

@SuppressWarnings("unchecked")
List<Double> queryVector = (List<Double>) aknnQueryMap.get(VECTOR_KEY);
stopWatch.stop();
// Clear the LSH model cache if requested.
if (clear_cache) {
lshModelCache.remove(aknnURI);
}
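// Hypothetical usage note: passing ?clear_cache=true on the request URL drops the
// cached model for this _aknn_uri, so InitLsh() below re-fetches it from the index.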
// Check if the LshModel has been cached. If not, retrieve the Aknn document and use it to populate the model.
LshModel lshModel = InitLsh(aknnURI, client);

stopWatch.start("Query nearest neighbors");
@SuppressWarnings("unchecked")
Map<String, Long> queryHashes = lshModel.getVectorHashes(queryVector);
//logger.info("HASHES: {}", queryHashes);


List<Map<String, Object>> modifiedSortedHits = QueryLsh(queryVector, queryHashes, index, type, k1, rescore, filter, minimum_should_match, debug, client);
stopWatch.stop();
logger.info("Timing summary\n {}", stopWatch.prettyPrint());
return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("took", stopWatch.totalTime().getMillis());
builder.field("timed_out", false);
builder.startObject("hits");
builder.field("max_score", 0);

// In some cases there will not be enough approximate matches to return *k2* hits. For example, this could
// be the case if the number of bits per table in the LSH model is too high, over-partitioning the space.
builder.field("total", min(k2, modifiedSortedHits.size()));
builder.field("hits", modifiedSortedHits.subList(0, min(k2, modifiedSortedHits.size())));
builder.endObject();
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}
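// Shape of the response assembled above (illustrative values only):
// { "took": 12, "timed_out": false,
//   "hits": { "max_score": 0, "total": 10, "hits": [ ...at most k2 entries... ] } }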


private RestChannelConsumer handleCreateRequest(RestRequest restRequest, NodeClient client) throws IOException {

StopWatch stopWatch = new StopWatch("StopWatch to time create request");
logger.info("Parse request");
stopWatch.start("Parse request");

XContentParser xContentParser = XContentHelper.createParser(
restRequest.getXContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
restRequest.content(),
restRequest.getXContentType());
Map<String, Object> contentMap = xContentParser.mapOrdered();
@SuppressWarnings("unchecked")
Map<String, Object> sourceMap = (Map<String, Object>) contentMap.get("_source");


final String _index = (String) contentMap.get("_index");
final String _type = (String) contentMap.get("_type");
final String _id = (String) contentMap.get("_id");
final String description = (String) sourceMap.get("_aknn_description");
final Integer nbTables = (Integer) sourceMap.get("_aknn_nb_tables");
final Integer nbBitsPerTable = (Integer) sourceMap.get("_aknn_nb_bits_per_table");
final Integer nbDimensions = (Integer) sourceMap.get("_aknn_nb_dimensions");
@SuppressWarnings("unchecked") final List<List<Double>> vectorSample = (List<List<Double>>) contentMap.get("_aknn_vector_sample");
stopWatch.stop();
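// A sketch of the create-model body parsed above (values are hypothetical;
// field names follow the contentMap/sourceMap lookups in this handler):
// {
//   "_index": "aknn_models", "_type": "aknn_model", "_id": "twitter_image_search",
//   "_source": {
//     "_aknn_description": "LSH model for image vectors",
//     "_aknn_nb_tables": 16,
//     "_aknn_nb_bits_per_table": 16,
//     "_aknn_nb_dimensions": 2
//   },
//   "_aknn_vector_sample": [ [0.1, 0.2], [0.3, 0.4] ]
// }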

logger.info("Fit LSH model from sample vectors");
@@ -363,27 +463,35 @@ private RestChannelConsumer handleIndexRequest(RestRequest restRequest, NodeClie
logger.info("Parse request parameters");
stopWatch.start("Parse request parameters");
XContentParser xContentParser = XContentHelper.createParser(
restRequest.getXContentRegistry(),
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
restRequest.content(),
restRequest.getXContentType());
Map<String, Object> contentMap = xContentParser.mapOrdered();
final String index = (String) contentMap.get("_index");
final String type = (String) contentMap.get("_type");
final String aknnURI = (String) contentMap.get("_aknn_uri");
final Boolean clear_cache = restRequest.paramAsBoolean("clear_cache", false);
@SuppressWarnings("unchecked") final List<Map<String, Object>> docs = (List<Map<String, Object>>) contentMap.get("_aknn_docs");
logger.info("Received {} docs for indexing", docs.size());
stopWatch.stop();
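// A sketch of the indexing body parsed above (hypothetical values; keys follow
// the contentMap lookups in this handler, and the vector field inside _source is
// assumed to be "_aknn_vector"):
// {
//   "_index": "twitter_images", "_type": "_doc",
//   "_aknn_uri": "aknn_models/aknn_model/twitter_image_search",
//   "_aknn_docs": [
//     { "_id": "1", "_source": { "_aknn_vector": [0.12, -0.45], "user": "some_user" } }
//   ]
// }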

// TODO: check if the index exists. If not, create a mapping which does not index continuous values.
// This is rather low priority, as I tried it via Python and it doesn't make much difference.

// Clear the LSH model cache if requested.
if (clear_cache) {
lshModelCache.remove(aknnURI);
}
// Check if the LshModel has been cached. If not, retrieve the Aknn document and use it to populate the model.
LshModel lshModel = InitLsh(aknnURI, client);

// Prepare documents for batch indexing.
logger.info("Hash documents for indexing");
stopWatch.start("Hash documents for indexing");
BulkRequestBuilder bulkIndexRequest = client.prepareBulk();
for (Map<String, Object> doc : docs) {
@SuppressWarnings("unchecked")
Map<String, Object> source = (Map<String, Object>) doc.get("_source");
@SuppressWarnings("unchecked")
@@ -424,4 +532,27 @@ private RestChannelConsumer handleIndexRequest(RestRequest restRequest, NodeClie
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}

private RestChannelConsumer handleClearRequest(RestRequest restRequest, NodeClient client) throws IOException {

// TODO: figure out how to execute cache clearing on all nodes at once.

StopWatch stopWatch = new StopWatch("StopWatch to time clear cache");
logger.info("Clearing LSH models cache");
stopWatch.start("Clearing cache");
lshModelCache.clear();
stopWatch.stop();
logger.info("Timing summary\n {}", stopWatch.prettyPrint());


return channel -> {
XContentBuilder builder = channel.newBuilder();
builder.startObject();
builder.field("took", stopWatch.totalTime().getMillis());
builder.field("acknowledged", true);
builder.endObject();
channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder));
};
}
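// Hypothetical invocation, assuming the route registered for NAME_CLEAR_CACHE is
// /_aknn_clear_cache per the commit message (and note the TODO above: this clears
// the cache only on the node that serves the request):
// GET /_aknn_clear_cache  ->  { "took": 1, "acknowledged": true }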

}
