Skip to content

Commit

Permalink
Change query encoder naming convention to Python (#2182)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurChen189 authored Sep 3, 2023
1 parent f2edb89 commit 9f90a27
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions src/main/java/io/anserini/search/SimpleImpactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ public Map<String, Result[]> batch_search_queries(List<String> queries,
* @throws OrtException if errors encountered during encoding
* @return encoded query
*/
public Map<String, Integer> encodeWithOnnx(String queryString) throws OrtException {
public Map<String, Integer> encode_with_onnx(String queryString) throws OrtException {
// if no query encoder, assume its encoded query split by whitespace
if (this.queryEncoder == null){
List<String> queryTokens = AnalyzerUtils.analyze(analyzer, queryString);
Expand All @@ -570,7 +570,7 @@ public Map<String, Integer> encodeWithOnnx(String queryString) throws OrtExcepti
* @throws OrtException if errors encountered during encoding
* @return encoded query
*/
public String encodeWithOnnx(Map<String, Integer> queryWeight) throws OrtException {
public String encode_with_onnx(Map<String, Integer> queryWeight) throws OrtException {
String encodedQ = "";
List<String> encodedQuery = new ArrayList<>();
for (Map.Entry<String, Integer> entry : queryWeight.entrySet()) {
Expand Down Expand Up @@ -622,7 +622,7 @@ public Result[] search(String q) throws IOException, OrtException {
public Result[] search(Map<String, Integer> encoded_q, int k) throws IOException, OrtException {
Map<String, Float> float_encoded_q = intToFloat(encoded_q);
Query query = generator.buildQuery(Constants.CONTENTS, float_encoded_q);
String encodedQuery = encodeWithOnnx(encoded_q);
String encodedQuery = encode_with_onnx(encoded_q);
return _search(query, encodedQuery, k);
}

Expand All @@ -637,8 +637,8 @@ public Result[] search(Map<String, Integer> encoded_q, int k) throws IOException
*/
public Result[] search(String q, int k) throws IOException, OrtException {
// make encoded query from raw query
Map<String, Integer> encoded_q = encodeWithOnnx(q);
String encodedQuery = encodeWithOnnx(encoded_q);
Map<String, Integer> encoded_q = encode_with_onnx(q);
String encodedQuery = encode_with_onnx(encoded_q);
Query query = generator.buildQuery(Constants.CONTENTS, analyzer, encodedQuery);
return _search(query, encodedQuery, k);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ public void testOnnxEncodedQuery() throws Exception {
SimpleImpactSearcher searcher = new SimpleImpactSearcher(super.tempDir1.toString());
Map<String, Integer> testQuery1 = new HashMap<>();
testQuery1.put("text", 2);
String encodedQuery = searcher.encodeWithOnnx(testQuery1);
String encodedQuery = searcher.encode_with_onnx(testQuery1);
assertEquals("text text" ,encodedQuery);
}

Expand All @@ -229,7 +229,7 @@ public void testOnnxEncoder() throws Exception{
SimpleImpactSearcher searcher = new SimpleImpactSearcher();
searcher.set_onnx_query_encoder("SpladePlusPlusEnsembleDistil");

Map<String, Integer> encoded_query = searcher.encodeWithOnnx("here is a test");
Map<String, Integer> encoded_query = searcher.encode_with_onnx("here is a test");
assertEquals(encoded_query.get("here"), EXPECTED_ENCODED_QUERY.get("here"), 2e-4);
assertEquals(encoded_query.get("a"), EXPECTED_ENCODED_QUERY.get("a"), 2e-4);
assertEquals(encoded_query.get("test"), EXPECTED_ENCODED_QUERY.get("test"), 2e-4);
Expand Down

0 comments on commit 9f90a27

Please sign in to comment.