Commit 2e49c19: Merge branch 'termfreqfreq' into bitmapfrequency

mkavanagh committed Sep 14, 2020
2 parents 04c2716 + 4885310

Showing 13 changed files with 673 additions and 14 deletions.
@@ -69,6 +69,7 @@
import org.apache.solr.search.facet.StddevAgg;
import org.apache.solr.search.facet.SumAgg;
import org.apache.solr.search.facet.SumsqAgg;
+import org.apache.solr.search.facet.TermFrequencyOfFrequenciesAgg;
import org.apache.solr.search.facet.TopDocsAgg;
import org.apache.solr.search.facet.UniqueAgg;
import org.apache.solr.search.facet.UniqueBlockAgg;
@@ -1071,6 +1072,8 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError {

addParser("agg_bitmapfreqfreq64", new FrequencyOfFrequenciesAgg64.Parser());

addParser("agg_termfreqfreq", new TermFrequencyOfFrequenciesAgg.Parser());

addParser("childfield", new ChildFieldValueSourceParser());
}
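For orientation, a hedged usage sketch (not part of this diff): assuming the registered agg_termfreqfreq parser follows the same agg_ prefix convention as the other aggregations above, the function would be referenced from a JSON facet request roughly as below. The facet name, field name, and term limit are placeholders.

import org.apache.solr.client.solrj.SolrQuery;

public class TermFreqFreqUsageSketch {
  public static void main(String[] args) {
    // Hypothetical request: a "termFreqs" facet computing the frequency-of-frequencies
    // over field category_s, keeping at most 100 terms per shard/slot.
    SolrQuery query = new SolrQuery("*:*");
    query.add("json.facet", "{ termFreqs : 'termfreqfreq(category_s, 100)' }");
    System.out.println(query);
  }
}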

org/apache/solr/search/facet/TermFrequencyCounter.java (new file)
@@ -0,0 +1,41 @@
package org.apache.solr.search.facet;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Accumulates how many times each term value has been seen, and can merge in
 * counts produced (and serialized) by another counter, e.g. from another shard.
 */
public class TermFrequencyCounter {
  private final Map<String, Integer> counters;

  public TermFrequencyCounter() {
    this.counters = new HashMap<>();
  }

  public Map<String, Integer> getCounters() {
    return this.counters;
  }

  /** Increments the count for the given term value. */
  public void add(String value) {
    counters.merge(value, 1, Integer::sum);
  }

  /**
   * Returns the counts to ship over the wire, keeping at most {@code limit}
   * entries (the terms with the highest counts).
   */
  public Map<String, Integer> serialize(int limit) {
    if (limit < Integer.MAX_VALUE && limit < counters.size()) {
      return counters.entrySet()
          .stream()
          .sorted((l, r) -> r.getValue() - l.getValue()) // sort by count, descending
          .limit(limit)
          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    } else {
      return counters;
    }
  }

  /** Adds serialized counts from another counter into this one. */
  public TermFrequencyCounter merge(Map<String, Integer> serialized) {
    serialized.forEach((value, freq) -> counters.merge(value, freq, Integer::sum));

    return this;
  }
}
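For context, a small sketch (not from the commit) of how the counter above is meant to be used: accumulate counts locally, then merge counts that another counter serialized.

import java.util.Map;

import org.apache.solr.search.facet.TermFrequencyCounter;

class TermFrequencyCounterSketch {
  public static void main(String[] args) {
    TermFrequencyCounter local = new TermFrequencyCounter();
    local.add("foo");
    local.add("foo");
    local.add("bar");

    TermFrequencyCounter remote = new TermFrequencyCounter();
    remote.add("foo");

    // serialize(limit) keeps at most `limit` terms, highest counts first;
    // Integer.MAX_VALUE means "keep everything"
    Map<String, Integer> wire = remote.serialize(Integer.MAX_VALUE);

    local.merge(wire);
    System.out.println(local.getCounters()); // foo=3, bar=1 (HashMap order not guaranteed)
  }
}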
org/apache/solr/search/facet/TermFrequencyOfFrequenciesAgg.java (new file)
@@ -0,0 +1,74 @@
package org.apache.solr.search.facet;

import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.lucene.queries.function.ValueSource;
import org.apache.solr.search.FunctionQParser;
import org.apache.solr.search.SyntaxError;
import org.apache.solr.search.ValueSourceParser;

/**
 * Aggregation that reports a "frequency of frequencies": for each distinct
 * per-term count, how many terms occurred that many times.
 */
public class TermFrequencyOfFrequenciesAgg extends SimpleAggValueSource {
  private final int termLimit;

  public TermFrequencyOfFrequenciesAgg(ValueSource vs, int termLimit) {
    super("termfreqfreq", vs);

    this.termLimit = termLimit;
  }

  @Override
  public SlotAcc createSlotAcc(FacetContext fcontext, int numDocs, int numSlots) {
    return new TermFrequencySlotAcc(getArg(), fcontext, numSlots, termLimit);
  }

  @Override
  public FacetMerger createFacetMerger(Object prototype) {
    return new Merger(termLimit);
  }

  public static class Parser extends ValueSourceParser {
    @Override
    public ValueSource parse(FunctionQParser fp) throws SyntaxError {
      ValueSource vs = fp.parseValueSource();

      // optional second argument: cap on the number of terms each slot serializes
      int termLimit = Integer.MAX_VALUE;
      if (fp.hasMoreArguments()) {
        termLimit = fp.parseInt();
      }

      return new TermFrequencyOfFrequenciesAgg(vs, termLimit);
    }
  }

  private static class Merger extends FacetMerger {
    private final TermFrequencyCounter result;

    public Merger(int termLimit) {
      // termLimit is applied when each shard serializes its per-slot counts
      // (see TermFrequencySlotAcc#getValue), so it is not needed again here
      this.result = new TermFrequencyCounter();
    }

    @Override
    @SuppressWarnings("unchecked")
    public void merge(Object facetResult, Context mcontext) {
      if (facetResult instanceof Map) {
        result.merge((Map<String, Integer>) facetResult);
      }
    }

    @Override
    public void finish(Context mcontext) {
      // never called
    }

    @Override
    public Object getMergedResult() {
      // invert term -> count into count -> number of terms seen that many times
      Map<Integer, Integer> map = new LinkedHashMap<>();

      result.getCounters()
          .forEach((value, freq) -> map.merge(freq, 1, Integer::sum));

      return map;
    }
  }
}
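A worked sketch (not from the commit) of the inversion getMergedResult() performs: merged term counts such as foo=3, bar=1, baz=1 become {3=1, 1=2}, i.e. one term occurred three times and two terms occurred once.

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

class FrequencyOfFrequenciesSketch {
  public static void main(String[] args) {
    // stand-in for the merged per-term counts held by the Merger's TermFrequencyCounter
    Map<String, Integer> termCounts = new HashMap<>();
    termCounts.put("foo", 3);
    termCounts.put("bar", 1);
    termCounts.put("baz", 1);

    // same transform as getMergedResult(): count -> number of terms with that count
    Map<Integer, Integer> freqOfFreq = new LinkedHashMap<>();
    termCounts.forEach((term, count) -> freqOfFreq.merge(count, 1, Integer::sum));

    System.out.println(freqOfFreq); // 3=1 and 1=2 (order depends on HashMap iteration)
  }
}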
org/apache/solr/search/facet/TermFrequencySlotAcc.java (new file)
@@ -0,0 +1,52 @@
package org.apache.solr.search.facet;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.function.IntFunction;

import org.apache.lucene.queries.function.ValueSource;

/**
 * Per-slot accumulator that counts how often each value of the wrapped
 * ValueSource is collected, returning at most {@code termLimit} terms per slot.
 */
public class TermFrequencySlotAcc extends FuncSlotAcc {
  private TermFrequencyCounter[] result;
  private final int termLimit;

  public TermFrequencySlotAcc(ValueSource values, FacetContext fcontext, int numSlots, int termLimit) {
    super(values, fcontext, numSlots);

    this.result = new TermFrequencyCounter[numSlots];
    this.termLimit = termLimit;
  }

  @Override
  public void collect(int doc, int slot, IntFunction<SlotContext> slotContext) throws IOException {
    // counters are created lazily, so empty slots stay null
    if (result[slot] == null) {
      result[slot] = new TermFrequencyCounter();
    }
    result[slot].add(values.strVal(doc));
  }

  @Override
  public int compare(int slotA, int slotB) {
    // slots cannot be ordered by this aggregation
    throw new UnsupportedOperationException();
  }

  @Override
  public Object getValue(int slotNum) {
    if (result[slotNum] != null) {
      return result[slotNum].serialize(termLimit);
    } else {
      return Collections.emptyList();
    }
  }

  @Override
  public void reset() {
    Arrays.fill(result, null);
  }

  @Override
  public void resize(Resizer resizer) {
    result = resizer.resize(result, null);
  }
}
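Sketch (not from the commit) of what getValue(slot) returns when a termLimit is configured: TermFrequencyCounter.serialize truncates each slot's counts to the most frequent terms.

import java.util.Map;

import org.apache.solr.search.facet.TermFrequencyCounter;

class TermLimitSketch {
  public static void main(String[] args) {
    TermFrequencyCounter slotCounts = new TermFrequencyCounter();
    for (int i = 0; i < 3; i++) slotCounts.add("a");
    for (int i = 0; i < 2; i++) slotCounts.add("b");
    slotCounts.add("c");

    // with termLimit = 2 only the two highest-count terms survive serialization
    Map<String, Integer> top2 = slotCounts.serialize(2);
    System.out.println(top2); // a=3 and b=2; "c" is dropped
  }
}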
@@ -40,9 +40,6 @@
* the chain and prints them on finish(). At the Debug (FINE) level, a message
* will be logged for each command prior to the next stage in the chain.
* </p>
- * <p>
- * If the Log level is not &gt;= INFO the processor will not be created or added to the chain.
- * </p>
*
* @since solr 1.3
*/
@@ -62,7 +59,7 @@ public void init( final NamedList args ) {

@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
-return log.isInfoEnabled() ? new LogUpdateProcessor(req, rsp, this, next) : null;
+return new LogUpdateProcessor(req, rsp, this, next);
}

static class LogUpdateProcessor extends UpdateRequestProcessor {
@@ -185,6 +182,8 @@ public void finish() throws IOException {

if (log.isInfoEnabled()) {
log.info(getLogStringAndClearRspToLog());
+} else {
+rsp.getToLog().clear();
}

if (log.isWarnEnabled() && slowUpdateThresholdMillis >= 0) {
org/apache/solr/search/facet/TermFrequencyCounterTest.java (new file)
@@ -0,0 +1,139 @@
package org.apache.solr.search.facet;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.util.LuceneTestCase;
import org.apache.solr.common.util.JavaBinCodec;
import org.junit.Test;

public class TermFrequencyCounterTest extends LuceneTestCase {
  private static final char[] ALPHABET = "abcdefghijklmnopqrstuvwxyz".toCharArray();

  @Test
  public void testAddValue() throws IOException {
    int iters = 10 * RANDOM_MULTIPLIER;

    for (int i = 0; i < iters; i++) {
      TermFrequencyCounter counter = new TermFrequencyCounter();

      // add random values with random repeat counts, tracking the expected totals
      int numValues = random().nextInt(100);
      Map<String, Integer> expected = new HashMap<>();
      for (int j = 0; j < numValues; j++) {
        String value = randomString(ALPHABET, random().nextInt(256));
        int count = random().nextInt(256);

        addCount(counter, value, count);

        expected.merge(value, count, Integer::sum);
      }

      expected.forEach((value, count) -> assertCount(counter, value, count));

      // counts must survive a JavaBin serialize/deserialize round trip
      TermFrequencyCounter serialized = serdeser(counter, random().nextInt(Integer.MAX_VALUE));

      expected.forEach((value, count) -> assertCount(serialized, value, count));
    }
  }

  @Test
  public void testMerge() throws IOException {
    int iters = 10 * RANDOM_MULTIPLIER;

    for (int i = 0; i < iters; i++) {
      TermFrequencyCounter x = new TermFrequencyCounter();

      int numXValues = random().nextInt(100);
      Map<String, Integer> expectedXValues = new HashMap<>();
      for (int j = 0; j < numXValues; j++) {
        String value = randomString(ALPHABET, random().nextInt(256));
        int count = random().nextInt(256);

        addCount(x, value, count);

        expectedXValues.merge(value, count, Integer::sum);
      }

      expectedXValues.forEach((value, count) -> assertCount(x, value, count));

      TermFrequencyCounter y = new TermFrequencyCounter();

      int numYValues = random().nextInt(100);
      Map<String, Integer> expectedYValues = new HashMap<>();
      for (int j = 0; j < numYValues; j++) {
        String value = randomString(ALPHABET, random().nextInt(256));
        int count = random().nextInt(256);

        addCount(y, value, count);

        expectedYValues.merge(value, count, Integer::sum);
      }

      expectedYValues.forEach((value, count) -> assertCount(y, value, count));

      // merging y into x (via the serialized form) should yield the summed counts
      TermFrequencyCounter merged = merge(x, y, random().nextInt(Integer.MAX_VALUE));

      expectedYValues.forEach((value, count) -> expectedXValues.merge(value, count, Integer::sum));

      expectedXValues.forEach((value, count) -> assertCount(merged, value, count));
    }
  }

  /** Builds a random string of the given length from the supplied alphabet. */
  private static String randomString(char[] alphabet, int length) {
    final StringBuilder sb = new StringBuilder(length);
    for (int i = 0; i < length; i++) {
      sb.append(alphabet[random().nextInt(alphabet.length)]);
    }
    return sb.toString();
  }

  /** Adds {@code value} to the counter {@code count} times. */
  private static void addCount(TermFrequencyCounter counter, String value, int count) {
    for (int i = 0; i < count; i++) {
      counter.add(value);
    }
  }

  private static void assertCount(TermFrequencyCounter counter, String value, int count) {
    assertEquals(
        "value " + value + " should have count " + count,
        count,
        (int) counter.getCounters().getOrDefault(value, 0)
    );
  }

  /** Round-trips the counter's serialized form through JavaBin and back into a new counter. */
  @SuppressWarnings("unchecked")
  private static TermFrequencyCounter serdeser(TermFrequencyCounter counter, int limit) throws IOException {
    JavaBinCodec codec = new JavaBinCodec();

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    codec.marshal(counter.serialize(limit), out);

    InputStream in = new ByteArrayInputStream(out.toByteArray());
    counter = new TermFrequencyCounter();
    counter.merge((Map<String, Integer>) codec.unmarshal(in));

    return counter;
  }

  /** Serializes {@code toMerge} through JavaBin and merges it into {@code counter}. */
  @SuppressWarnings("unchecked")
  private static TermFrequencyCounter merge(
      TermFrequencyCounter counter,
      TermFrequencyCounter toMerge,
      int limit
  ) throws IOException {
    JavaBinCodec codec = new JavaBinCodec();

    ByteArrayOutputStream out = new ByteArrayOutputStream();
    codec.marshal(toMerge.serialize(limit), out);

    InputStream in = new ByteArrayInputStream(out.toByteArray());
    counter.merge((Map<String, Integer>) codec.unmarshal(in));

    return counter;
  }
}