Skip to content

Commit

Permalink
Add tests for fetch phase
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <[email protected]>
  • Loading branch information
rishabhmaurya committed Apr 5, 2024
1 parent ef24c12 commit abbc52d
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,27 +83,27 @@ public DerivedFieldValueFetcher valueFetcher(QueryShardContext context, SearchLo
if (format != null) {
throw new IllegalArgumentException("Field [" + name() + "] of type [" + typeName() + "] doesn't support formats.");
}
return new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
return new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context, searchLookup));
}

@Override
public Query termQuery(Object value, QueryShardContext context) {
Query query = typeFieldMapper.mappedFieldType.termQuery(value, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query termQueryCaseInsensitive(Object value, @Nullable QueryShardContext context) {
Query query = typeFieldMapper.mappedFieldType.termQueryCaseInsensitive(value, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query termsQuery(List<?> values, @Nullable QueryShardContext context) {
Query query = typeFieldMapper.mappedFieldType.termsQuery(values, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -128,7 +128,7 @@ public Query rangeQuery(
parser,
context
);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -142,7 +142,7 @@ public Query fuzzyQuery(
QueryShardContext context
) {
Query query = typeFieldMapper.mappedFieldType.fuzzyQuery(value, fuzziness, prefixLength, maxExpansions, transpositions, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -165,7 +165,7 @@ public Query fuzzyQuery(
method,
context
);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -177,7 +177,7 @@ public Query prefixQuery(
QueryShardContext context
) {
Query query = typeFieldMapper.mappedFieldType.prefixQuery(value, method, caseInsensitive, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -189,14 +189,14 @@ public Query wildcardQuery(
QueryShardContext context
) {
Query query = typeFieldMapper.mappedFieldType.wildcardQuery(value, method, caseInsensitive, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query normalizedWildcardQuery(String value, @Nullable MultiTermQuery.RewriteMethod method, QueryShardContext context) {
Query query = typeFieldMapper.mappedFieldType.normalizedWildcardQuery(value, method, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -210,29 +210,29 @@ public Query regexpQuery(
QueryShardContext context
) {
Query query = typeFieldMapper.mappedFieldType.regexpQuery(value, syntaxFlags, matchFlags, maxDeterminizedStates, method, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query phraseQuery(TokenStream stream, int slop, boolean enablePositionIncrements, QueryShardContext context) throws IOException {
Query query = typeFieldMapper.mappedFieldType.phraseQuery(stream, slop, enablePositionIncrements, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query multiPhraseQuery(TokenStream stream, int slop, boolean enablePositionIncrements, QueryShardContext context)
throws IOException {
Query query = typeFieldMapper.mappedFieldType.multiPhraseQuery(stream, slop, enablePositionIncrements, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

@Override
public Query phrasePrefixQuery(TokenStream stream, int slop, int maxExpansions, QueryShardContext context) throws IOException {
Query query = typeFieldMapper.mappedFieldType.phrasePrefixQuery(stream, slop, maxExpansions, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -246,7 +246,7 @@ public SpanQuery spanPrefixQuery(String value, SpanMultiTermQueryWrapper.SpanRew
@Override
public Query distanceFeatureQuery(Object origin, String pivot, float boost, QueryShardContext context) {
Query query = typeFieldMapper.mappedFieldType.distanceFeatureQuery(origin, pivot, boost, context);
DerivedFieldValueFetcher valueFetcher = new DerivedFieldValueFetcher(getDerivedFieldLeafFactory(context));
DerivedFieldValueFetcher valueFetcher = valueFetcher(context, context.lookup(), null);
return new DerivedFieldQuery(query, valueFetcher, context.lookup(), indexableFieldGenerator, getIndexAnalyzer());
}

Expand All @@ -260,7 +260,7 @@ public boolean isAggregatable() {
return false;
}

private DerivedFieldScript.LeafFactory getDerivedFieldLeafFactory(QueryShardContext context) {
private DerivedFieldScript.LeafFactory getDerivedFieldLeafFactory(QueryShardContext context, SearchLookup searchLookup) {
if (!context.documentMapper("").sourceMapper().enabled()) {
throw new IllegalArgumentException(
"DerivedFieldQuery error: unable to fetch fields from _source field: _source is disabled in the mappings "
Expand All @@ -270,6 +270,6 @@ private DerivedFieldScript.LeafFactory getDerivedFieldLeafFactory(QueryShardCont
);
}
DerivedFieldScript.Factory factory = context.compile(derivedField.getScript(), DerivedFieldScript.CONTEXT);
return factory.newFactory(derivedField.getScript().getParams(), context.lookup());
return factory.newFactory(derivedField.getScript().getParams(), searchLookup);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public List<Object> fetchValues(SourceLookup lookup) {
return derivedFieldScript.getEmittedValues();
}

@Override
public void setNextReader(LeafReaderContext context) {
try {
derivedFieldScript = derivedFieldScriptFactory.newInstance(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,6 @@ public DerivedFieldScript(Map<String, Object> params, SearchLookup lookup, LeafR
this.totalByteSize = 0;
}

public DerivedFieldScript() {
this.params = null;
this.leafLookup = null;
this.emittedValues = new ArrayList<>();
this.totalByteSize = 0;
}

/**
* Return the parameters for this script.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,34 @@
import org.opensearch.index.IndexSettings;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.script.MockScriptEngine;
import org.opensearch.script.ScriptEngine;
import org.opensearch.script.ScriptModule;
import org.opensearch.script.ScriptService;
import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;
import org.opensearch.test.OpenSearchSingleNodeTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.Collections.singletonMap;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItems;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class FieldFetcherTests extends OpenSearchSingleNodeTestCase {

private static String DERIVED_FIELD_SCRIPT_1 = "derived_field_script_1";
private static String DERIVED_FIELD_SCRIPT_2 = "derived_field_script_2";

public void testLeafValues() throws IOException {
MapperService mapperService = createMapperService();
XContentBuilder source = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -435,6 +449,45 @@ public void testTextSubFields() throws IOException {
}
}

public void testDerivedFields() throws IOException {
XContentBuilder mapping = XContentFactory.jsonBuilder()
.startObject()
.startObject("derived")
.startObject("derived_1")
.field("type", "keyword")
.startObject("script")
.field("source", DERIVED_FIELD_SCRIPT_1)
.field("lang", "mockscript")
.endObject()
.endObject()
.startObject("derived_2")
.field("type", "keyword")
.startObject("script")
.field("source", DERIVED_FIELD_SCRIPT_2)
.field("lang", "mockscript")
.endObject()
.endObject()
.endObject()
.endObject();

IndexService indexService = createIndex("index", Settings.EMPTY, MapperService.SINGLE_MAPPING_NAME, mapping);
MapperService mapperService = indexService.mapperService();

XContentBuilder source = XContentFactory.jsonBuilder()
.startObject()
.field("field1", "some text 1")
.field("field2", "some text 2")
.endObject();

Map<String, DocumentField> fields = fetchFields(mapperService, source, "*");
assertThat(fields.size(), equalTo(2));
assertThat(fields.keySet(), containsInAnyOrder("derived_1", "derived_2"));
assertThat(fields.get("derived_1").getValues().size(), equalTo(1));
assertThat(fields.get("derived_2").getValues().size(), equalTo(1));
assertThat(fields.get("derived_1").getValue(), equalTo("some text 1"));
assertThat(fields.get("derived_2").getValue(), equalTo("some text 2"));
}

private static Map<String, DocumentField> fetchFields(MapperService mapperService, XContentBuilder source, String fieldPattern)
throws IOException {

Expand All @@ -448,7 +501,13 @@ private static Map<String, DocumentField> fetchFields(MapperService mapperServic
SourceLookup sourceLookup = new SourceLookup();
sourceLookup.setSource(BytesReference.bytes(source));

FieldFetcher fieldFetcher = FieldFetcher.create(createQueryShardContext(mapperService), null, fields);
SearchLookup searchLookup = mock(SearchLookup.class);
LeafSearchLookup leafSearchLookup = mock(LeafSearchLookup.class);
when(searchLookup.source()).thenReturn(sourceLookup);
when(searchLookup.getLeafSearchLookup(any())).thenReturn(leafSearchLookup);
when(leafSearchLookup.source()).thenReturn(sourceLookup);
FieldFetcher fieldFetcher = FieldFetcher.create(createQueryShardContext(mapperService), searchLookup, fields);
fieldFetcher.setNextReader(null);
return fieldFetcher.fetch(sourceLookup, Set.of());
}

Expand Down Expand Up @@ -497,6 +556,19 @@ private static QueryShardContext createQueryShardContext(MapperService mapperSer
.build();
IndexMetadata indexMetadata = new IndexMetadata.Builder("index").settings(settings).build();
IndexSettings indexSettings = new IndexSettings(indexMetadata, settings);

final MockScriptEngine engine = new MockScriptEngine(
MockScriptEngine.NAME,
Map.of(
DERIVED_FIELD_SCRIPT_1,
(script) -> ((Map<String, Object>) script.get("_source")).get("field1"),
DERIVED_FIELD_SCRIPT_2,
(script) -> ((Map<String, Object>) script.get("_source")).get("field2")
),
Collections.emptyMap()
);
final Map<String, ScriptEngine> engines = singletonMap(engine.getType(), engine);
ScriptService scriptService = new ScriptService(Settings.EMPTY, engines, ScriptModule.CORE_CONTEXTS);
return new QueryShardContext(
0,
indexSettings,
Expand All @@ -505,7 +577,7 @@ private static QueryShardContext createQueryShardContext(MapperService mapperSer
null,
mapperService,
null,
null,
scriptService,
null,
null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.opensearch.search.aggregations.pipeline.MovingFunctionScript;
import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;

import java.io.IOException;
import java.util.Collections;
Expand Down Expand Up @@ -282,16 +283,22 @@ public double execute(Map<String, Object> params1, double[] values) {
IntervalFilterScript.Factory factory = mockCompiled::createIntervalFilterScript;
return context.factoryClazz.cast(factory);
} else if (context.instanceClazz.equals(DerivedFieldScript.class)) {
DerivedFieldScript.Factory factory = (derivedFieldsParams, lookup) -> ctx -> new DerivedFieldScript(
derivedFieldsParams,
DerivedFieldScript.Factory factory = (derivedFieldParams, lookup) -> ctx -> new DerivedFieldScript(
derivedFieldParams,
lookup,
ctx
) {
@Override
public void setDocument(int docid) {}

@Override
public void execute() {
Map<String, Object> vars = new HashMap<>(derivedFieldsParams);
vars.put("params", derivedFieldsParams);
script.apply(vars);
Map<String, Object> vars = new HashMap<>(derivedFieldParams);
SourceLookup sourceLookup = lookup.source();
vars.put("params", derivedFieldParams);
vars.put("_source", sourceLookup.loadSourceIfNeeded());
// currently supports adding one value, can be extended to emit multiple values too.
addEmittedValue(script.apply(vars));
}
};
return context.factoryClazz.cast(factory);
Expand Down

0 comments on commit abbc52d

Please sign in to comment.