Skip to content

Commit

Permalink
MultiRangeQuery for searching IP masks more 1025 masks in indexed field.
Browse files Browse the repository at this point in the history
Signed-off-by: Mikhail Khludnev <[email protected]>
  • Loading branch information
mkhludnev committed Nov 21, 2024
1 parent d4d70d8 commit 75b2719
Show file tree
Hide file tree
Showing 4 changed files with 375 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Bump `Netty` from 4.1.114.Final to 4.1.115.Final ([#16661](https://github.com/opensearch-project/OpenSearch/pull/16661))

### Changed
- Indexed IP field supports `terms_query` with more than 1025 IP masks [#16391](https://github.com/opensearch-project/OpenSearch/pull/16391)

### Deprecated

Expand Down
149 changes: 117 additions & 32 deletions server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.sandbox.search.MultiRangeQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
Expand All @@ -47,6 +51,7 @@
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.network.NetworkAddress;
import org.opensearch.index.fielddata.IndexFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.plain.SortedSetOrdinalsIndexFieldData;
Expand All @@ -58,13 +63,13 @@
import java.io.IOException;
import java.net.InetAddress;
import java.time.ZoneId;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* A {@link FieldMapper} for ip addresses.
Expand Down Expand Up @@ -262,43 +267,99 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) {
@Override
public Query termsQuery(List<?> values, QueryShardContext context) {
failIfNotIndexedAndNoDocValues();
InetAddress[] addresses = new InetAddress[values.size()];
int i = 0;
for (Object value : values) {
InetAddress address;
if (value instanceof InetAddress) {
address = (InetAddress) value;
} else {
if (value instanceof BytesRef) {
value = ((BytesRef) value).utf8ToString();
Tuple<List<InetAddress>, List<String>> ipsMasks = splitIpsAndMasks(values);
List<Query> combiner = new ArrayList<>();
convertIps(ipsMasks.v1(), combiner);
convertMasks(ipsMasks.v2(), context, combiner);
if (combiner.size() == 1) {
return combiner.get(0);
}
return new ConstantScoreQuery(union(combiner));
}

private Query union(List<Query> combiner) {
BooleanQuery.Builder bqb = new BooleanQuery.Builder();
for (Query q : combiner) {
bqb.add(q, BooleanClause.Occur.SHOULD);
}
return bqb.build();
}

private void convertIps(List<InetAddress> inetAddresses, List<Query> sink) {
if (!inetAddresses.isEmpty() && (isSearchable() || hasDocValues())) {
Query pointsQuery = null;
if (isSearchable()) {
pointsQuery = inetAddresses.size() == 1
? InetAddressPoint.newExactQuery(name(), inetAddresses.iterator().next())
: InetAddressPoint.newSetQuery(name(), inetAddresses.toArray(new InetAddress[0]));
}
Query dvQuery = null;
if (hasDocValues()) {
List<BytesRef> set = new ArrayList<>(inetAddresses.size());
for (final InetAddress address : inetAddresses) {
set.add(new BytesRef(InetAddressPoint.encode(address)));
}
if (value.toString().contains("/")) {
// the `terms` query contains some prefix queries, so we cannot create a set query
// and need to fall back to a disjunction of `term` queries
return super.termsQuery(values, context);
if (set.size() == 1) {
dvQuery = SortedSetDocValuesField.newSlowExactQuery(name(), set.iterator().next());
} else {
dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set);
}
address = InetAddresses.forString(value.toString());
}
addresses[i++] = address;
}
Query dvQuery = null;
if (hasDocValues()) {
List<BytesRef> bytesRefs = Arrays.stream(addresses)
.distinct()
.map(InetAddressPoint::encode)
.map(BytesRef::new)
.collect(Collectors.<BytesRef>toList());
dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), bytesRefs);
final Query out;
if (isSearchable() && hasDocValues()) {
out = new IndexOrDocValuesQuery(pointsQuery, dvQuery);
} else {
out = isSearchable() ? pointsQuery : dvQuery;
}
sink.add(out);
}
Query pointQuery = null;
if (isSearchable()) {
pointQuery = InetAddressPoint.newSetQuery(name(), addresses);
}

private void convertMasks(List<String> masks, QueryShardContext context, List<Query> sink) {
if (!masks.isEmpty() && (isSearchable() || hasDocValues())) {
IpMultiRangeQueryBuilder multiRange = null;
for (String mask : masks) {
final Tuple<InetAddress, Integer> cidr = InetAddresses.parseCidr(mask);
PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2());
if (isSearchable()) { // even there is DV we don't go with it, since we can't guess clauses limit
if (multiRange == null) {
multiRange = new IpMultiRangeQueryBuilder(name());
}
multiRange.add(query.getLowerPoint(), query.getUpperPoint());
} else { // it may hit clauses limit sooner or later
Query dvRange = SortedSetDocValuesField.newSlowRangeQuery(
name(),
new BytesRef(query.getLowerPoint()),
new BytesRef(query.getUpperPoint()),
true,
true
);
sink.add(dvRange);
}
}
// never IndexOrDocValuesQuery() since we can't guess clauses limit
if (multiRange != null) {
sink.add(multiRange.build());
}
}
if (isSearchable() && hasDocValues()) {
return new IndexOrDocValuesQuery(pointQuery, dvQuery);
} else {
return isSearchable() ? pointQuery : dvQuery;
}

private static Tuple<List<InetAddress>, List<String>> splitIpsAndMasks(List<?> values) {
List<InetAddress> concreteIPs = new ArrayList<>();
List<String> masks = new ArrayList<>();
for (final Object value : values) {
if (value instanceof InetAddress) {
concreteIPs.add((InetAddress) value);
} else {
final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString();
if (strVal.contains("/")) {
masks.add(strVal);
} else {
concreteIPs.add(InetAddresses.forString(strVal));
}
}
}
return Tuple.tuple(concreteIPs, masks);
}

@Override
Expand Down Expand Up @@ -445,6 +506,30 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) {
}
}

/**
* Union over IP address ranges
*/
public static class IpMultiRangeQueryBuilder extends MultiRangeQuery.Builder {
public IpMultiRangeQueryBuilder(String field) {
super(field, InetAddressPoint.BYTES, 1);
}

public IpMultiRangeQueryBuilder add(InetAddress lower, InetAddress upper) {
add(new MultiRangeQuery.RangeClause(InetAddressPoint.encode(lower), InetAddressPoint.encode(upper)));
return this;
}

@Override
public MultiRangeQuery build() {
return new MultiRangeQuery(field, numDims, bytesPerDim, clauses) {
@Override
protected String toString(int dimension, byte[] value) {
return NetworkAddress.format(InetAddressPoint.decode(value));
}
};
}
}

private final boolean indexed;
private final boolean hasDocValues;
private final boolean stored;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,12 @@ public void testTermsQuery() {
);

// if the list includes a prefix query we fallback to a bool query
assertEquals(
new ConstantScoreQuery(
new BooleanQuery.Builder().add(ft.termQuery("::42", null), Occur.SHOULD)
.add(ft.termQuery("::2/16", null), Occur.SHOULD)
.build()
),
ft.termsQuery(Arrays.asList("::42", "::2/16"), null)
);
Query actual = ft.termsQuery(Arrays.asList("::42", "::2/16"), null);
assertTrue(actual instanceof ConstantScoreQuery);
assertTrue(((ConstantScoreQuery) actual).getQuery() instanceof BooleanQuery);
BooleanQuery bq = (BooleanQuery) ((ConstantScoreQuery) actual).getQuery();
assertEquals(2, bq.clauses().size());
assertTrue(bq.clauses().stream().allMatch(c -> c.getOccur() == Occur.SHOULD));
}

public void testDvOnlyTermsQuery() {
Expand All @@ -238,6 +236,14 @@ public void testDvOnlyTermsQuery() {
);
}

public void testDvVsPoint() {
MappedFieldType indexOnly = new IpFieldMapper.IpFieldType("field", true, false, false, null, Collections.emptyMap());
MappedFieldType dvOnly = new IpFieldMapper.IpFieldType("field", false, false, true, null, Collections.emptyMap());
MappedFieldType indexDv = new IpFieldMapper.IpFieldType("field", true, false, true, null, Collections.emptyMap());
assertEquals("ignore DV", indexOnly.termsQuery(List.of("::2/16"), null), indexDv.termsQuery(List.of("::2/16"), null));
assertEquals(dvOnly.termQuery("::2/16", null), dvOnly.termsQuery(List.of("::2/16"), null));
}

public void testRangeQuery() {
MappedFieldType ft = new IpFieldMapper.IpFieldType("field");
Query query = InetAddressPoint.newRangeQuery("field", InetAddresses.forString("::"), InetAddressPoint.MAX_VALUE);
Expand Down
Loading

0 comments on commit 75b2719

Please sign in to comment.