From d3d7199681735b66ebe5e805c1538eba519e133d Mon Sep 17 00:00:00 2001
From: Richard Goodman
Date: Fri, 3 Jul 2020 12:07:56 +0100
Subject: [PATCH 1/4] Add SOLR-14241 delete streaming support
---
.../src/stream-decorator-reference.adoc | 44 ++++-
.../org/apache/solr/client/solrj/io/Lang.java | 3 +-
.../client/solrj/io/stream/DeleteStream.java | 112 +++++++++++++
.../client/solrj/io/stream/UpdateStream.java | 50 +++++-
.../configsets/streaming/conf/solrconfig.xml | 5 +
.../apache/solr/client/solrj/io/TestLang.java | 2 +-
.../solrj/io/stream/StreamDecoratorTest.java | 155 +++++++++++++++++-
7 files changed, 361 insertions(+), 10 deletions(-)
create mode 100644 solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DeleteStream.java
diff --git a/solr/solr-ref-guide/src/stream-decorator-reference.adoc b/solr/solr-ref-guide/src/stream-decorator-reference.adoc
index 1465b83b807a..c99cd971af5e 100644
--- a/solr/solr-ref-guide/src/stream-decorator-reference.adoc
+++ b/solr/solr-ref-guide/src/stream-decorator-reference.adoc
@@ -595,6 +595,45 @@ while(true) {
daemonStream.close();
----
+== delete
+
+The `delete` function wraps another function, and uses the `id` and `\_version_` values found in the tuples to send them to a SolrCloud collection as Delete By Id commands.
+
+This is similar to the `<<#update,update()>>` function described below.
+
+=== delete Parameters
+
+* `destinationCollection`: (Mandatory) The collection the tuples will be deleted from.
+* `batchSize`: (Mandatory) The delete batch size.
+* `pruneVersionField`: (Optional, defaults to `false`) Whether to prune `\_version_` values from tuples.
+* `StreamExpression`: (Mandatory)
+
+=== delete Syntax
+
+[source,text]
+----
+  delete(collection1,
+         batchSize=500,
+         search(collection1,
+                q=old_data:true,
+                qt="/export",
+                fl="id",
+                sort="a_f asc, a_i asc"))
+----
+
+The example above consumes the tuples returned by the `search` function against `collection1` and converts the `id` value of each document found into a delete request against the same `collection1`.
+
+[NOTE]
+====
+Unlike the `update()` function, `delete()` defaults to `pruneVersionField=false`, preserving any `\_version_` values found in the inner stream when converting the tuples to "Delete By ID" requests. Because of Optimistic Concurrency constraints, this ensures that using this stream will not (by default) delete any documents that were updated _after_ the `search(...)` was executed but _before_ the `delete(...)` processed that tuple.
+
+Users who wish to ignore concurrent updates and delete all matched documents should set `pruneVersionField=true` (or ensure that the inner stream tuples do not include any `\_version_` values).
+
+Users who anticipate concurrent updates, and wish to "skip" any failed deletes, should instead consider configuring the {solr-javadocs}/solr-core/org/apache/solr/update/processor/TolerantUpdateProcessorFactory.html[`TolerantUpdateProcessorFactory`].
+====
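+
+As a sketch of this default behavior, the earlier expression can request the `\_version_` field from the inner stream (assuming the field has docValues enabled, as required by the `/export` handler), making each delete subject to Optimistic Concurrency checks:
+
+[source,text]
+----
+  delete(collection1,
+         batchSize=500,
+         search(collection1,
+                q=old_data:true,
+                qt="/export",
+                fl="id,_version_",
+                sort="a_f asc, a_i asc"))
+----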
+
+
== eval
The `eval` function allows for use cases where new streaming expressions are generated on the fly and then evaluated.
@@ -1273,12 +1312,13 @@ unique(
== update
-The `update` function wraps another functions and sends the tuples to a SolrCloud collection for indexing.
+The `update` function wraps another function and sends the tuples to a SolrCloud collection for indexing as documents.
=== update Parameters
* `destinationCollection`: (Mandatory) The collection where the tuples will indexed.
* `batchSize`: (Mandatory) The indexing batch size.
+* `pruneVersionField`: (Optional, defaults to `true`) Whether to prune `\_version_` values from tuples.
* `StreamExpression`: (Mandatory)
=== update Syntax
@@ -1296,3 +1336,5 @@ The `update` function wraps another functions and sends the tuples to a SolrClou
----
The example above sends the tuples returned by the `search` function to the `destinationCollection` to be indexed.
+
+Wrapping `search(...)` as shown in this example is the common usage of this decorator: to read documents from a collection as tuples, process or modify them in some way, and then add them back to a new collection. For this reason, `pruneVersionField=true` is the default behavior, stripping any `\_version_` values found in the inner stream when converting the tuples to Solr documents, to prevent any unexpected errors from Optimistic Concurrency constraints.
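+
+As a sketch (the field list is illustrative), explicitly passing `pruneVersionField=false` preserves any `\_version_` values found in the tuples, so the destination collection will enforce Optimistic Concurrency constraints on the resulting updates:
+
+[source,text]
+----
+  update(destinationCollection,
+         batchSize=500,
+         pruneVersionField=false,
+         search(collection1,
+                q=*:*,
+                qt="/export",
+                fl="id,a_s,a_i,a_f,_version_",
+                sort="a_f asc, a_i asc"))
+----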
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
index a1a796d1ecfe..9c3d71c804a9 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
@@ -40,6 +40,7 @@ public static void register(StreamFactory streamFactory) {
.withFunctionName("facet", FacetStream.class)
.withFunctionName("update", UpdateStream.class)
.withFunctionName("jdbc", JDBCStream.class)
+ .withFunctionName("delete", DeleteStream.class)
.withFunctionName("topic", TopicStream.class)
.withFunctionName("commit", CommitStream.class)
.withFunctionName("random", RandomStream.class)
@@ -336,4 +337,4 @@ public static void register(StreamFactory streamFactory) {
.withFunctionName("if", IfThenElseEvaluator.class)
.withFunctionName("convert", ConversionEvaluator.class);
}
-}
\ No newline at end of file
+}
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DeleteStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DeleteStream.java
new file mode 100644
index 000000000000..929db8ecba60
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/DeleteStream.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.stream;
+
+import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.util.List;
+import java.util.Locale;
+
+import org.apache.solr.client.solrj.SolrServerException;
+import org.apache.solr.client.solrj.request.UpdateRequest;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.stream.expr.Explanation;
+import org.apache.solr.client.solrj.io.stream.expr.Expressible;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.common.SolrInputDocument;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.solr.common.params.CommonParams.VERSION_FIELD;
+
+/**
+ * Uses tuples to identify the uniqueKey values of documents to be deleted
+ */
+public final class DeleteStream extends UpdateStream implements Expressible {
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ private static final String ID_TUPLE_KEY = "id";
+
+ public DeleteStream(StreamExpression expression, StreamFactory factory) throws IOException {
+ super(expression, factory);
+ }
+
+ @Override
+ public Explanation toExplanation(StreamFactory factory) throws IOException {
+ final Explanation explanation = super.toExplanation(factory);
+ explanation.setExpression("Delete docs from " + getCollectionName());
+
+ return explanation;
+ }
+
+ /**
+ * {@link DeleteStream} returns false so that Optimistic Concurrency Constraints are
+ * respected by default when using this stream to wrap a {@link SearchStream} query.
+ */
+ @Override
+ protected boolean defaultPruneVersionField() {
+ return false;
+ }
+
+ /**
+ * Overrides implementation to extract the "id" and "_version_"
+ * (if included) from each document and use that information to construct a "Delete By Id" request.
+ * Any other fields (ie: Tuple values) are ignored.
+ */
+ @Override
+  protected void uploadBatchToCollection(List<SolrInputDocument> documentBatch) throws IOException {
+ if (documentBatch.size() == 0) {
+ return;
+ }
+
+ try {
+ // convert each doc into a deleteById request...
+ final UpdateRequest req = new UpdateRequest();
+ for (SolrInputDocument doc : documentBatch) {
+ final String id = doc.getFieldValue(ID_TUPLE_KEY).toString();
+ final Long version = getVersion(doc);
+ req.deleteById(id, version);
+ }
+ req.process(getCloudSolrClient(), getCollectionName());
+    } catch (SolrServerException | NumberFormatException | IOException e) {
+ log.warn("Unable to delete documents from collection due to unexpected error.", e);
+ String className = e.getClass().getName();
+ String message = e.getMessage();
+      throw new IOException(String.format(Locale.ROOT, "Unexpected error when deleting documents from collection %s - %s:%s", getCollectionName(), className, message));
+ }
+ }
+
+ /**
+ * Helper method that can handle String values when dealing with odd
+ * {@link Tuple} -> {@link SolrInputDocument} conversions
+ * (ie: tuple(..) in tests)
+ */
+ private static Long getVersion(final SolrInputDocument doc) throws NumberFormatException {
+ if (! doc.containsKey(VERSION_FIELD)) {
+ return null;
+ }
+ final Object v = doc.getFieldValue(VERSION_FIELD);
+ if (null == v) {
+ return null;
+ }
+ if (v instanceof Long) {
+ return (Long)v;
+ }
+ return Long.parseLong(v.toString());
+ }
+}
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/UpdateStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/UpdateStream.java
index c00de1020c5b..94a2151b0e1a 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/UpdateStream.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/UpdateStream.java
@@ -40,10 +40,10 @@
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.common.params.CommonParams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import static org.apache.solr.common.params.CommonParams.VERSION_FIELD;
/**
* Sends tuples emitted by a wrapped {@link TupleStream} as updates to a SolrCloud collection.
@@ -56,6 +56,13 @@ public class UpdateStream extends TupleStream implements Expressible {
private String collection;
private String zkHost;
private int updateBatchSize;
+ /**
+ * Indicates if the {@link CommonParams#VERSION_FIELD} should be removed from tuples when converting
+ * to Solr Documents.
+   * May be set per expression using the "pruneVersionField" named operand;
+   * defaults to the value returned by {@link #defaultPruneVersionField()}.
+ */
+ private boolean pruneVersionField;
private int batchNumber;
private long totalDocsIndex;
private PushBackStream tupleSource;
@@ -64,7 +71,6 @@ public class UpdateStream extends TupleStream implements Expressible {
private List<SolrInputDocument> documentBatch = new ArrayList();
private String coreName;
-
public UpdateStream(StreamExpression expression, StreamFactory factory) throws IOException {
String collectionName = factory.getValueOperand(expression, 0);
verifyCollectionName(collectionName, expression);
@@ -73,6 +79,7 @@ public UpdateStream(StreamExpression expression, StreamFactory factory) throws I
verifyZkHost(zkHost, collectionName, expression);
int updateBatchSize = extractBatchSize(expression, factory);
+ pruneVersionField = factory.getBooleanOperand(expression, "pruneVersionField", defaultPruneVersionField());
//Extract underlying TupleStream.
List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
@@ -80,7 +87,6 @@ public UpdateStream(StreamExpression expression, StreamFactory factory) throws I
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting a single stream but found %d",expression, streamExpressions.size()));
}
StreamExpression sourceStreamExpression = streamExpressions.get(0);
-
init(collectionName, factory.constructStream(sourceStreamExpression), zkHost, updateBatchSize);
}
@@ -88,9 +94,10 @@ public UpdateStream(String collectionName, TupleStream tupleSource, String zkHos
if (updateBatchSize <= 0) {
throw new IOException(String.format(Locale.ROOT,"batchSize '%d' must be greater than 0.", updateBatchSize));
}
+ pruneVersionField = defaultPruneVersionField();
init(collectionName, tupleSource, zkHost, updateBatchSize);
}
-
+
private void init(String collectionName, TupleStream tupleSource, String zkHost, int updateBatchSize) {
this.collection = collectionName;
this.zkHost = zkHost;
@@ -98,6 +105,11 @@ private void init(String collectionName, TupleStream tupleSource, String zkHost,
this.tupleSource = new PushBackStream(tupleSource);
}
+ /** The name of the collection being updated */
+ protected String getCollectionName() {
+ return collection;
+ }
+
@Override
public void open() throws IOException {
setCloudSolrClient();
@@ -257,6 +269,21 @@ private int parseBatchSize(String batchSizeStr, StreamExpression expression) thr
throw new IOException(String.format(Locale.ROOT,"invalid expression %s - batchSize '%s' is not a valid integer.",expression, batchSizeStr));
}
}
+
+ /**
+ * Used during initialization to specify the default value for the "pruneVersionField" option.
+ * {@link UpdateStream} returns true for backcompat and to simplify slurping of data from one
+ * collection to another.
+ */
+ protected boolean defaultPruneVersionField() {
+ return true;
+ }
+
+ /** Only viable after calling {@link #open} */
+ protected CloudSolrClient getCloudSolrClient() {
+ assert null != this.cloudSolrClient;
+ return this.cloudSolrClient;
+ }
private void setCloudSolrClient() {
if(this.cache != null) {
@@ -272,7 +299,8 @@ private void setCloudSolrClient() {
private SolrInputDocument convertTupleToSolrDocument(Tuple tuple) {
SolrInputDocument doc = new SolrInputDocument();
for (Object field : tuple.fields.keySet()) {
- if (! field.equals(VERSION_FIELD)) {
+
+ if (! (field.equals(CommonParams.VERSION_FIELD) && pruneVersionField)) {
Object value = tuple.get(field);
if (value instanceof List) {
addMultivaluedField(doc, (String)field, (List<?>)value);
diff --git a/solr/core/src/java/org/apache/solr/update/processor/LogUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/LogUpdateProcessorFactory.java
- *
- * If the Log level is not >= INFO the processor will not be created or added to the chain.
- *
*
* @since solr 1.3
*/
@@ -62,7 +59,7 @@ public void init( final NamedList args ) {
@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
- return log.isInfoEnabled() ? new LogUpdateProcessor(req, rsp, this, next) : null;
+ return new LogUpdateProcessor(req, rsp, this, next);
}
static class LogUpdateProcessor extends UpdateRequestProcessor {
@@ -185,6 +182,8 @@ public void finish() throws IOException {
if (log.isInfoEnabled()) {
log.info(getLogStringAndClearRspToLog());
+ } else {
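+        // not logging at INFO; still clear the entries accumulated in rsp.getToLog(),
+        // as getLogStringAndClearRspToLog() would have done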
+ rsp.getToLog().clear();
}
if (log.isWarnEnabled() && slowUpdateThresholdMillis >= 0) {
From 4885310f3eabbccd8a04beaea75c8490f713ff65 Mon Sep 17 00:00:00 2001
From: Matthew Kavanagh
Date: Wed, 2 Sep 2020 13:34:01 +0100
Subject: [PATCH 4/4] hashmap based freq-of-freq agg for str vals
---
.../apache/solr/search/ValueSourceParser.java | 3 +
.../search/facet/TermFrequencyCounter.java | 41 ++++++
.../facet/TermFrequencyOfFrequenciesAgg.java | 74 ++++++++++
.../search/facet/TermFrequencySlotAcc.java | 52 +++++++
.../facet/TermFrequencyCounterTest.java | 139 ++++++++++++++++++
5 files changed, 309 insertions(+)
create mode 100644 solr/core/src/java/org/apache/solr/search/facet/TermFrequencyCounter.java
create mode 100644 solr/core/src/java/org/apache/solr/search/facet/TermFrequencyOfFrequenciesAgg.java
create mode 100644 solr/core/src/java/org/apache/solr/search/facet/TermFrequencySlotAcc.java
create mode 100755 solr/core/src/test/org/apache/solr/search/facet/TermFrequencyCounterTest.java
diff --git a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
index e013e9eeeab8..b3f83fefbf06 100644
--- a/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
+++ b/solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
@@ -64,6 +64,7 @@
import org.apache.solr.search.facet.SumAgg;
import org.apache.solr.search.facet.SumsqAgg;
import org.apache.solr.search.facet.RelatednessAgg;
+import org.apache.solr.search.facet.TermFrequencyOfFrequenciesAgg;
import org.apache.solr.search.facet.TopDocsAgg;
import org.apache.solr.search.facet.UniqueAgg;
import org.apache.solr.search.facet.UniqueBlockAgg;
@@ -1056,6 +1057,8 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
addParser("agg_topdocs", new TopDocsAgg.Parser());
+ addParser("agg_termfreqfreq", new TermFrequencyOfFrequenciesAgg.Parser());
+
addParser("childfield", new ChildFieldValueSourceParser());
}
diff --git a/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyCounter.java b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyCounter.java
new file mode 100644
index 000000000000..4b786ed92418
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyCounter.java
@@ -0,0 +1,41 @@
+package org.apache.solr.search.facet;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.stream.Collectors;
+
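+/** Accumulates how many times each distinct String value has been added; counts can be serialized (top-N) and merged across shards. */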
+public class TermFrequencyCounter {
+  private final Map<String, Integer> counters;
+
+ public TermFrequencyCounter() {
+ this.counters = new HashMap<>();
+ }
+
+  public Map<String, Integer> getCounters() {
+ return this.counters;
+ }
+
+ public void add(String value) {
+ counters.merge(value, 1, Integer::sum);
+ }
+
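+  /** Returns at most {@code limit} entries, keeping the terms with the highest counts. */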
+  public Map<String, Integer> serialize(int limit) {
+ if (limit < Integer.MAX_VALUE && limit < counters.size()) {
+ return counters.entrySet()
+ .stream()
+ .sorted((l, r) -> r.getValue() - l.getValue()) // sort by value descending
+ .limit(limit)
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ } else {
+ return counters;
+ }
+ }
+
+  public TermFrequencyCounter merge(Map<String, Integer> serialized) {
+ serialized.forEach((value, freq) -> counters.merge(value, freq, Integer::sum));
+
+ return this;
+ }
+}
diff --git a/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyOfFrequenciesAgg.java b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyOfFrequenciesAgg.java
new file mode 100644
index 000000000000..7726c5648476
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencyOfFrequenciesAgg.java
@@ -0,0 +1,74 @@
+package org.apache.solr.search.facet;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.lucene.queries.function.ValueSource;
+import org.apache.solr.search.FunctionQParser;
+import org.apache.solr.search.SyntaxError;
+import org.apache.solr.search.ValueSourceParser;
+
+public class TermFrequencyOfFrequenciesAgg extends SimpleAggValueSource {
+ private final int termLimit;
+
+ public TermFrequencyOfFrequenciesAgg(ValueSource vs, int termLimit) {
+ super("termfreqfreq", vs);
+
+ this.termLimit = termLimit;
+ }
+
+ @Override
+ public SlotAcc createSlotAcc(FacetContext fcontext, int numDocs, int numSlots) {
+ return new TermFrequencySlotAcc(getArg(), fcontext, numSlots, termLimit);
+ }
+
+ @Override
+ public FacetMerger createFacetMerger(Object prototype) {
+ return new Merger(termLimit);
+ }
+
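+  /**
+   * Parses the aggregation from a facet request, e.g. {@code termfreqfreq(field_s, 100)}
+   * (a sketch of intended usage; the facet function name follows from the
+   * {@code agg_termfreqfreq} registration). The optional second argument caps how many
+   * terms each shard serializes.
+   */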
+ public static class Parser extends ValueSourceParser {
+ @Override
+ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
+ ValueSource vs = fp.parseValueSource();
+
+ int termLimit = Integer.MAX_VALUE;
+ if (fp.hasMoreArguments()) {
+ termLimit = fp.parseInt();
+ }
+
+ return new TermFrequencyOfFrequenciesAgg(vs, termLimit);
+ }
+ }
+
+ private static class Merger extends FacetMerger {
+ private final TermFrequencyCounter result;
+
+ public Merger(int termLimit) {
+ this.result = new TermFrequencyCounter();
+ }
+
+ @Override
+ public void merge(Object facetResult, Context mcontext) {
+ if (facetResult instanceof Map) {
+        result.merge((Map<String, Integer>) facetResult);
+ }
+ }
+
+ @Override
+ public void finish(Context mcontext) {
+ // never called
+ }
+
+ @Override
+ public Object getMergedResult() {
+      Map<Integer, Integer> map = new LinkedHashMap<>();
+
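+      // invert the merged term -> count map into count -> number-of-terms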
+ result.getCounters()
+ .forEach((value, freq) -> map.merge(freq, 1, Integer::sum));
+
+ return map;
+ }
+ }
+}
diff --git a/solr/core/src/java/org/apache/solr/search/facet/TermFrequencySlotAcc.java b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencySlotAcc.java
new file mode 100644
index 000000000000..a3e5b30603ce
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/search/facet/TermFrequencySlotAcc.java
@@ -0,0 +1,52 @@
+package org.apache.solr.search.facet;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.function.IntFunction;
+
+import org.apache.lucene.queries.function.ValueSource;
+
+public class TermFrequencySlotAcc extends FuncSlotAcc {
+ private TermFrequencyCounter[] result;
+ private final int termLimit;
+
+ public TermFrequencySlotAcc(ValueSource values, FacetContext fcontext, int numSlots, int termLimit) {
+ super(values, fcontext, numSlots);
+
+ this.result = new TermFrequencyCounter[numSlots];
+ this.termLimit = termLimit;
+ }
+
+ @Override
+  public void collect(int doc, int slot, IntFunction<SlotContext> slotContext) throws IOException {
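+    // lazily allocate this slot's counter the first time it collects a value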
+ if (result[slot] == null) {
+ result[slot] = new TermFrequencyCounter();
+ }
+ result[slot].add(values.strVal(doc));
+ }
+
+ @Override
+ public int compare(int slotA, int slotB) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Object getValue(int slotNum) {
+ if (result[slotNum] != null) {
+ return result[slotNum].serialize(termLimit);
+ } else {
+ return Collections.emptyList();
+ }
+ }
+
+ @Override
+ public void reset() {
+ Arrays.fill(result, null);
+ }
+
+ @Override
+ public void resize(Resizer resizer) {
+ result = resizer.resize(result, null);
+ }
+}
diff --git a/solr/core/src/test/org/apache/solr/search/facet/TermFrequencyCounterTest.java b/solr/core/src/test/org/apache/solr/search/facet/TermFrequencyCounterTest.java
new file mode 100755
index 000000000000..36a728aaba65
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/search/facet/TermFrequencyCounterTest.java
@@ -0,0 +1,139 @@
+package org.apache.solr.search.facet;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.solr.common.util.JavaBinCodec;
+import org.junit.Test;
+
+public class TermFrequencyCounterTest extends LuceneTestCase {
+  private static final char[] ALPHABET = "abcdefghijklmnopqrstuvwxyz".toCharArray();
+
+ @Test
+ public void testAddValue() throws IOException {
+ int iters = 10 * RANDOM_MULTIPLIER;
+
+ for (int i = 0; i < iters; i++) {
+ TermFrequencyCounter counter = new TermFrequencyCounter();
+
+ int numValues = random().nextInt(100);
+      Map<String, Integer> expected = new HashMap<>();
+ for (int j = 0; j < numValues; j++) {
+ String value = randomString(ALPHABET, random().nextInt(256));
+ int count = random().nextInt(256);
+
+ addCount(counter, value, count);
+
+ expected.merge(value, count, Integer::sum);
+ }
+
+ expected.forEach((value, count) -> assertCount(counter, value, count));
+
+ TermFrequencyCounter serialized = serdeser(counter, random().nextInt(Integer.MAX_VALUE));
+
+ expected.forEach((value, count) -> assertCount(serialized, value, count));
+ }
+ }
+
+ @Test
+ public void testMerge() throws IOException {
+ int iters = 10 * RANDOM_MULTIPLIER;
+
+ for (int i = 0; i < iters; i++) {
+ TermFrequencyCounter x = new TermFrequencyCounter();
+
+ int numXValues = random().nextInt(100);
+      Map<String, Integer> expectedXValues = new HashMap<>();
+ for (int j = 0; j < numXValues; j++) {
+ String value = randomString(ALPHABET, random().nextInt(256));
+ int count = random().nextInt(256);
+
+ addCount(x, value, count);
+
+ expectedXValues.merge(value, count, Integer::sum);
+ }
+
+ expectedXValues.forEach((value, count) -> assertCount(x, value, count));
+
+ TermFrequencyCounter y = new TermFrequencyCounter();
+
+ int numYValues = random().nextInt(100);
+      Map<String, Integer> expectedYValues = new HashMap<>();
+ for (int j = 0; j < numYValues; j++) {
+ String value = randomString(ALPHABET, random().nextInt(256));
+ int count = random().nextInt(256);
+
+ addCount(y, value, count);
+
+ expectedYValues.merge(value, count, Integer::sum);
+ }
+
+ expectedYValues.forEach((value, count) -> assertCount(y, value, count));
+
+ TermFrequencyCounter merged = merge(x, y, random().nextInt(Integer.MAX_VALUE));
+
+ expectedYValues.forEach((value, count) -> expectedXValues.merge(value, count, Integer::sum));
+
+ expectedXValues.forEach((value, count) -> assertCount(merged, value, count));
+ }
+ }
+
+ private static String randomString(char[] alphabet, int length) {
+ final StringBuilder sb = new StringBuilder(length);
+ for (int i = 0; i < length; i++) {
+ sb.append(alphabet[random().nextInt(alphabet.length)]);
+ }
+ return sb.toString();
+ }
+
+ private static void addCount(TermFrequencyCounter counter, String value, int count) {
+ for (int i = 0; i < count; i++) {
+ counter.add(value);
+ }
+ }
+
+ private static void assertCount(TermFrequencyCounter counter, String value, int count) {
+ assertEquals(
+ "value " + value + " should have count " + count,
+ count,
+ (int) counter.getCounters().getOrDefault(value, 0)
+ );
+ }
+
+ private static TermFrequencyCounter serdeser(TermFrequencyCounter counter, int limit) throws IOException {
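+    // round-trip through JavaBinCodec to mimic how per-shard facet results are transported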
+ JavaBinCodec codec = new JavaBinCodec();
+
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ codec.marshal(counter.serialize(limit), out);
+
+ InputStream in = new ByteArrayInputStream(out.toByteArray());
+ counter = new TermFrequencyCounter();
+    counter.merge((Map<String, Integer>) codec.unmarshal(in));
+
+ return counter;
+ }
+
+ private static TermFrequencyCounter merge(
+ TermFrequencyCounter counter,
+ TermFrequencyCounter toMerge,
+ int limit
+ ) throws IOException {
+ JavaBinCodec codec = new JavaBinCodec();
+
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ codec.marshal(toMerge.serialize(limit), out);
+
+ InputStream in = new ByteArrayInputStream(out.toByteArray());
+    counter.merge((Map<String, Integer>) codec.unmarshal(in));
+
+ return counter;
+ }
+}