Skip to content

Commit

Permalink
Add SplitResponseProcessor for search pipelines
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Jul 17, 2024
1 parent a2cef8f commit 63a721f
Show file tree
Hide file tree
Showing 2 changed files with 348 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Processor that sorts an array of items.
* Throws exception is the specified field is not an array.
*/
public class SplitResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
/** Key to reference this processor type from a search pipeline. */
public static final String TYPE = "split";
/** Key defining the string field to be split. */
public static final String SPLIT_FIELD = "field";
/** Key defining the delimiter used to split the string. This can be a regular expression pattern. */
public static final String SEPARATOR = "separator";
/** Optional key for handling empty trailing fields. */
public static final String PRESERVE_TRAILING = "preserve_trailing";
/** Optional key to put the split values in a different field. */
public static final String TARGET_FIELD = "target_field";

private final String splitField;
private final String separator;
private final boolean preserveTrailing;
private final String targetField;

SplitResponseProcessor(
String tag,
String description,
boolean ignoreFailure,
String splitField,
String separator,
boolean preserveTrailing,
String targetField
) {
super(tag, description, ignoreFailure);
this.splitField = Objects.requireNonNull(splitField);
this.separator = Objects.requireNonNull(separator);
this.preserveTrailing = preserveTrailing;
this.targetField = targetField == null ? splitField : targetField;
}

/**
* Getter function for splitField
* @return sortField
*/
public String getSplitField() {
return splitField;
}

/**
* Getter function for separator
* @return separator
*/
public String getSeparator() {
return separator;
}

/**
* Getter function for preserveTrailing
* @return preserveTrailing;
*/
public boolean isPreserveTrailing() {
return preserveTrailing;
}

/**
* Getter function for targetField
* @return targetField
*/
public String getTargetField() {
return targetField;
}

@Override
public String getType() {
return TYPE;
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
SearchHit[] hits = response.getHits().getHits();
for (SearchHit hit : hits) {
Map<String, DocumentField> fields = hit.getFields();
if (fields.containsKey(splitField)) {
DocumentField docField = hit.getFields().get(splitField);
if (docField == null) {
throw new IllegalArgumentException("field [" + splitField + "] is null, cannot split.");
}
Object val = docField.getValue();
if (val == null || !String.class.isAssignableFrom(val.getClass())) {
throw new IllegalArgumentException("field [" + splitField + "] is not a string, cannot split");
}
String[] strings = ((String) val).split(separator, preserveTrailing ? -1 : 0);
List<Object> splitList = Stream.of(strings).collect(Collectors.toList());
hit.setDocumentField(targetField, new DocumentField(targetField, splitList));
}
if (hit.hasSource()) {
BytesReference sourceRef = hit.getSourceRef();
Tuple<? extends MediaType, Map<String, Object>> typeAndSourceMap = XContentHelper.convertToMap(
sourceRef,
false,
(MediaType) null
);

Map<String, Object> sourceAsMap = typeAndSourceMap.v2();
if (sourceAsMap.containsKey(splitField)) {
Object val = sourceAsMap.get(splitField);
if (val instanceof String) {
String[] strings = ((String) val).split(separator, preserveTrailing ? -1 : 0);
List<Object> splitList = Stream.of(strings).collect(Collectors.toList());
sourceAsMap.put(targetField, splitList);
}
XContentBuilder builder = XContentBuilder.builder(typeAndSourceMap.v1().xContent());
builder.map(sourceAsMap);
hit.sourceRef(BytesReference.bytes(builder));
}
}
}
return response;
}

static class Factory implements Processor.Factory<SearchResponseProcessor> {

@Override
public SplitResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) {
String splitField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "field");
String separator = ConfigurationUtils.readStringProperty(TYPE, tag, config, "separator");
boolean preserveTrailing = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, "preserve_trailing", false);
String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "target_field", splitField);
return new SplitResponseProcessor(tag, description, ignoreFailure, splitField, separator, preserveTrailing, targetField);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a.java
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.apache.lucene.search.TotalHits;
import org.opensearch.OpenSearchParseException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.common.document.DocumentField;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ingest.RandomDocumentPicks;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SplitResponseProcessorTests extends OpenSearchTestCase {

private static final String NO_TRAILING = "one,two,three";
private static final String TRAILING = "alpha,beta,gamma,";

private SearchRequest createDummyRequest() {
QueryBuilder query = new TermQueryBuilder("field", "value");
SearchSourceBuilder source = new SearchSourceBuilder().query(query);
return new SearchRequest().source(source);
}

private SearchResponse createTestResponse() {
SearchHit[] hits = new SearchHit[2];

// one response with source
Map<String, DocumentField> csvMap = new HashMap<>();
csvMap.put("csv", new DocumentField("csv", List.of(NO_TRAILING)));
hits[0] = new SearchHit(0, "doc 1", csvMap, Collections.emptyMap());
hits[0].sourceRef(new BytesArray("{ \"csv\" : \"" + NO_TRAILING + "\" }"));
hits[0].score(1f);

// one without source
csvMap = new HashMap<>();
csvMap.put("csv", new DocumentField("csv", List.of(TRAILING)));
hits[1] = new SearchHit(1, "doc 2", csvMap, Collections.emptyMap());
hits[1].score(2f);

SearchHits searchHits = new SearchHits(hits, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 2);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
}

private SearchResponse createTestResponseNullField() {
SearchHit[] hits = new SearchHit[1];

Map<String, DocumentField> map = new HashMap<>();
map.put("csv", null);
hits[0] = new SearchHit(0, "doc 1", map, Collections.emptyMap());
hits[0].sourceRef(new BytesArray("{ \"csv\" : null }"));
hits[0].score(1f);

SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
}

private SearchResponse createTestResponseEmptyList() {
SearchHit[] hits = new SearchHit[1];

Map<String, DocumentField> map = new HashMap<>();
map.put("empty", new DocumentField("empty", List.of()));
hits[0] = new SearchHit(0, "doc 1", map, Collections.emptyMap());
hits[0].sourceRef(new BytesArray("{ \"empty\" : [] }"));
hits[0].score(1f);

SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
}

private SearchResponse createTestResponseNotString() {
SearchHit[] hits = new SearchHit[1];

Map<String, DocumentField> piMap = new HashMap<>();
piMap.put("maps", new DocumentField("maps", List.of(Map.of("foo", "I'm the Map!"))));
hits[0] = new SearchHit(0, "doc 1", piMap, Collections.emptyMap());
hits[0].sourceRef(new BytesArray("{ \"maps\" : [{ \"foo\" : \"I'm the Map!\"}]] }"));
hits[0].score(1f);

SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1);
SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
return new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null);
}

public void testSplitResponse() throws Exception {
SearchRequest request = createDummyRequest();

SplitResponseProcessor splitResponseProcessor = new SplitResponseProcessor(null, null, false, "csv", ",", false, "split");
SearchResponse response = createTestResponse();
SearchResponse splitResponse = splitResponseProcessor.processResponse(request, response);

assertEquals(response.getHits(), splitResponse.getHits());

assertEquals(NO_TRAILING, splitResponse.getHits().getHits()[0].field("csv").getValue());
assertEquals(List.of("one", "two", "three"), splitResponse.getHits().getHits()[0].field("split").getValues());
Map<String, Object> map = splitResponse.getHits().getHits()[0].getSourceAsMap();
assertNotNull(map);
assertEquals(List.of("one", "two", "three"), map.get("split"));

assertEquals(TRAILING, splitResponse.getHits().getHits()[1].field("csv").getValue());
assertEquals(List.of("alpha", "beta", "gamma"), splitResponse.getHits().getHits()[1].field("split").getValues());
assertNull(splitResponse.getHits().getHits()[1].getSourceAsMap());
}

public void testSplitResponseSameField() throws Exception {
SearchRequest request = createDummyRequest();

SplitResponseProcessor splitResponseProcessor = new SplitResponseProcessor(null, null, false, "csv", ",", true, null);
SearchResponse response = createTestResponse();
SearchResponse splitResponse = splitResponseProcessor.processResponse(request, response);

assertEquals(response.getHits(), splitResponse.getHits());
assertEquals(List.of("one", "two", "three"), splitResponse.getHits().getHits()[0].field("csv").getValues());
assertEquals(List.of("alpha", "beta", "gamma", ""), splitResponse.getHits().getHits()[1].field("csv").getValues());
}

public void testSplitResponseEmptyList() {
SearchRequest request = createDummyRequest();

SplitResponseProcessor splitResponseProcessor = new SplitResponseProcessor(null, null, false, "empty", ",", false, null);
assertThrows(IllegalArgumentException.class, () -> splitResponseProcessor.processResponse(request, createTestResponseEmptyList()));
}

public void testNullField() {
SearchRequest request = createDummyRequest();

SplitResponseProcessor splitResponseProcessor = new SplitResponseProcessor(null, null, false, "csv", ",", false, null);

assertThrows(IllegalArgumentException.class, () -> splitResponseProcessor.processResponse(request, createTestResponseNullField()));
}

public void testNotStringField() {
SearchRequest request = createDummyRequest();

SplitResponseProcessor splitResponseProcessor = new SplitResponseProcessor(null, null, false, "maps", ",", false, null);

assertThrows(IllegalArgumentException.class, () -> splitResponseProcessor.processResponse(request, createTestResponseNotString()));
}

public void testFactory() {
String splitField = RandomDocumentPicks.randomFieldName(random());
String targetField = RandomDocumentPicks.randomFieldName(random());
Map<String, Object> config = new HashMap<>();
config.put("field", splitField);
config.put("separator", ",");
config.put("preserve_trailing", true);
config.put("target_field", targetField);

SplitResponseProcessor.Factory factory = new SplitResponseProcessor.Factory();
SplitResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null);
assertEquals("split", processor.getType());
assertEquals(splitField, processor.getSplitField());
assertEquals(",", processor.getSeparator());
assertTrue(processor.isPreserveTrailing());
assertEquals(targetField, processor.getTargetField());

expectThrows(
OpenSearchParseException.class,
() -> factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null)
);
}
}

0 comments on commit 63a721f

Please sign in to comment.