diff --git a/CHANGELOG.md b/CHANGELOG.md index 00754925ea111..13521ad472ee7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Bump `Netty` from 4.1.114.Final to 4.1.115.Final ([#16661](https://github.com/opensearch-project/OpenSearch/pull/16661)) ### Changed +- Indexed IP field supports `terms_query` with more than 1025 IP masks [#16391](https://github.com/opensearch-project/OpenSearch/pull/16391) ### Deprecated diff --git a/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java index c51cada9f3143..97ef0f0a099fc 100644 --- a/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java +++ b/server/src/main/java/org/opensearch/index/mapper/IpFieldMapper.java @@ -36,6 +36,10 @@ import org.apache.lucene.document.SortedSetDocValuesField; import org.apache.lucene.document.StoredField; import org.apache.lucene.index.SortedSetDocValues; +import org.apache.lucene.sandbox.search.MultiRangeQuery; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.IndexOrDocValuesQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.PointRangeQuery; @@ -47,6 +51,7 @@ import org.opensearch.common.collect.Tuple; import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.common.network.InetAddresses; +import org.opensearch.common.network.NetworkAddress; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.plain.SortedSetOrdinalsIndexFieldData; @@ -58,13 +63,13 @@ import java.io.IOException; import java.net.InetAddress; import java.time.ZoneId; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.function.BiFunction; import java.util.function.Supplier; -import java.util.stream.Collectors; /** * A {@link FieldMapper} for ip addresses. @@ -262,43 +267,99 @@ public Query termQuery(Object value, @Nullable QueryShardContext context) { @Override public Query termsQuery(List values, QueryShardContext context) { failIfNotIndexedAndNoDocValues(); - InetAddress[] addresses = new InetAddress[values.size()]; - int i = 0; - for (Object value : values) { - InetAddress address; - if (value instanceof InetAddress) { - address = (InetAddress) value; - } else { - if (value instanceof BytesRef) { - value = ((BytesRef) value).utf8ToString(); + Tuple, List> ipsMasks = splitIpsAndMasks(values); + List combiner = new ArrayList<>(); + convertIps(ipsMasks.v1(), combiner); + convertMasks(ipsMasks.v2(), context, combiner); + if (combiner.size() == 1) { + return combiner.get(0); + } + return new ConstantScoreQuery(union(combiner)); + } + + private Query union(List combiner) { + BooleanQuery.Builder bqb = new BooleanQuery.Builder(); + for (Query q : combiner) { + bqb.add(q, BooleanClause.Occur.SHOULD); + } + return bqb.build(); + } + + private void convertIps(List inetAddresses, List sink) { + if (!inetAddresses.isEmpty() && (isSearchable() || hasDocValues())) { + Query pointsQuery = null; + if (isSearchable()) { + pointsQuery = inetAddresses.size() == 1 + ? InetAddressPoint.newExactQuery(name(), inetAddresses.iterator().next()) + : InetAddressPoint.newSetQuery(name(), inetAddresses.toArray(new InetAddress[0])); + } + Query dvQuery = null; + if (hasDocValues()) { + List set = new ArrayList<>(inetAddresses.size()); + for (final InetAddress address : inetAddresses) { + set.add(new BytesRef(InetAddressPoint.encode(address))); } - if (value.toString().contains("/")) { - // the `terms` query contains some prefix queries, so we cannot create a set query - // and need to fall back to a disjunction of `term` queries - return super.termsQuery(values, context); + if (set.size() == 1) { + dvQuery = SortedSetDocValuesField.newSlowExactQuery(name(), set.iterator().next()); + } else { + dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), set); } - address = InetAddresses.forString(value.toString()); } - addresses[i++] = address; - } - Query dvQuery = null; - if (hasDocValues()) { - List bytesRefs = Arrays.stream(addresses) - .distinct() - .map(InetAddressPoint::encode) - .map(BytesRef::new) - .collect(Collectors.toList()); - dvQuery = SortedSetDocValuesField.newSlowSetQuery(name(), bytesRefs); + final Query out; + if (isSearchable() && hasDocValues()) { + out = new IndexOrDocValuesQuery(pointsQuery, dvQuery); + } else { + out = isSearchable() ? pointsQuery : dvQuery; + } + sink.add(out); } - Query pointQuery = null; - if (isSearchable()) { - pointQuery = InetAddressPoint.newSetQuery(name(), addresses); + } + + private void convertMasks(List masks, QueryShardContext context, List sink) { + if (!masks.isEmpty() && (isSearchable() || hasDocValues())) { + IpMultiRangeQueryBuilder multiRange = null; + for (String mask : masks) { + final Tuple cidr = InetAddresses.parseCidr(mask); + PointRangeQuery query = (PointRangeQuery) InetAddressPoint.newPrefixQuery(name(), cidr.v1(), cidr.v2()); + if (isSearchable()) { // even there is DV we don't go with it, since we can't guess clauses limit + if (multiRange == null) { + multiRange = new IpMultiRangeQueryBuilder(name()); + } + multiRange.add(query.getLowerPoint(), query.getUpperPoint()); + } else { // it may hit clauses limit sooner or later + Query dvRange = SortedSetDocValuesField.newSlowRangeQuery( + name(), + new BytesRef(query.getLowerPoint()), + new BytesRef(query.getUpperPoint()), + true, + true + ); + sink.add(dvRange); + } + } + // never IndexOrDocValuesQuery() since we can't guess clauses limit + if (multiRange != null) { + sink.add(multiRange.build()); + } } - if (isSearchable() && hasDocValues()) { - return new IndexOrDocValuesQuery(pointQuery, dvQuery); - } else { - return isSearchable() ? pointQuery : dvQuery; + } + + private static Tuple, List> splitIpsAndMasks(List values) { + List concreteIPs = new ArrayList<>(); + List masks = new ArrayList<>(); + for (final Object value : values) { + if (value instanceof InetAddress) { + concreteIPs.add((InetAddress) value); + } else { + final String strVal = (value instanceof BytesRef) ? ((BytesRef) value).utf8ToString() : value.toString(); + if (strVal.contains("/")) { + masks.add(strVal); + } else { + concreteIPs.add(InetAddresses.forString(strVal)); + } + } } + return Tuple.tuple(concreteIPs, masks); } @Override @@ -445,6 +506,30 @@ public DocValueFormat docValueFormat(@Nullable String format, ZoneId timeZone) { } } + /** + * Union over IP address ranges + */ + public static class IpMultiRangeQueryBuilder extends MultiRangeQuery.Builder { + public IpMultiRangeQueryBuilder(String field) { + super(field, InetAddressPoint.BYTES, 1); + } + + public IpMultiRangeQueryBuilder add(InetAddress lower, InetAddress upper) { + add(new MultiRangeQuery.RangeClause(InetAddressPoint.encode(lower), InetAddressPoint.encode(upper))); + return this; + } + + @Override + public MultiRangeQuery build() { + return new MultiRangeQuery(field, numDims, bytesPerDim, clauses) { + @Override + protected String toString(int dimension, byte[] value) { + return NetworkAddress.format(InetAddressPoint.decode(value)); + } + }; + } + } + private final boolean indexed; private final boolean hasDocValues; private final boolean stored; diff --git a/server/src/test/java/org/opensearch/index/mapper/IpFieldTypeTests.java b/server/src/test/java/org/opensearch/index/mapper/IpFieldTypeTests.java index a5403ef81481f..243164c5fe8fb 100644 --- a/server/src/test/java/org/opensearch/index/mapper/IpFieldTypeTests.java +++ b/server/src/test/java/org/opensearch/index/mapper/IpFieldTypeTests.java @@ -205,14 +205,12 @@ public void testTermsQuery() { ); // if the list includes a prefix query we fallback to a bool query - assertEquals( - new ConstantScoreQuery( - new BooleanQuery.Builder().add(ft.termQuery("::42", null), Occur.SHOULD) - .add(ft.termQuery("::2/16", null), Occur.SHOULD) - .build() - ), - ft.termsQuery(Arrays.asList("::42", "::2/16"), null) - ); + Query actual = ft.termsQuery(Arrays.asList("::42", "::2/16"), null); + assertTrue(actual instanceof ConstantScoreQuery); + assertTrue(((ConstantScoreQuery) actual).getQuery() instanceof BooleanQuery); + BooleanQuery bq = (BooleanQuery) ((ConstantScoreQuery) actual).getQuery(); + assertEquals(2, bq.clauses().size()); + assertTrue(bq.clauses().stream().allMatch(c -> c.getOccur() == Occur.SHOULD)); } public void testDvOnlyTermsQuery() { @@ -238,6 +236,14 @@ public void testDvOnlyTermsQuery() { ); } + public void testDvVsPoint() { + MappedFieldType indexOnly = new IpFieldMapper.IpFieldType("field", true, false, false, null, Collections.emptyMap()); + MappedFieldType dvOnly = new IpFieldMapper.IpFieldType("field", false, false, true, null, Collections.emptyMap()); + MappedFieldType indexDv = new IpFieldMapper.IpFieldType("field", true, false, true, null, Collections.emptyMap()); + assertEquals("ignore DV", indexOnly.termsQuery(List.of("::2/16"), null), indexDv.termsQuery(List.of("::2/16"), null)); + assertEquals(dvOnly.termQuery("::2/16", null), dvOnly.termsQuery(List.of("::2/16"), null)); + } + public void testRangeQuery() { MappedFieldType ft = new IpFieldMapper.IpFieldType("field"); Query query = InetAddressPoint.newRangeQuery("field", InetAddresses.forString("::"), InetAddressPoint.MAX_VALUE); diff --git a/server/src/test/java/org/opensearch/search/SearchIpFieldTermsTests.java b/server/src/test/java/org/opensearch/search/SearchIpFieldTermsTests.java new file mode 100644 index 0000000000000..3a3bd9bdd1609 --- /dev/null +++ b/server/src/test/java/org/opensearch/search/SearchIpFieldTermsTests.java @@ -0,0 +1,243 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search; + +import org.apache.lucene.search.IndexSearcher; +import org.opensearch.action.bulk.BulkRequestBuilder; +import org.opensearch.action.search.SearchPhaseExecutionException; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.network.InetAddresses; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.test.OpenSearchSingleNodeTestCase; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Consumer; + +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.hamcrest.Matchers.equalTo; + +public class SearchIpFieldTermsTests extends OpenSearchSingleNodeTestCase { + + /** + * @return number of expected matches + * */ + private int createIndex(String indexName, int numberOfMasks, List queryTermsSink) throws IOException { + XContentBuilder xcb = createMapping(); + client().admin().indices().prepareCreate(indexName).setMapping(xcb).get(); + ensureGreen(); + + BulkRequestBuilder bulkRequestBuilder = client().prepareBulk(); + + Set dedupeCidrs = new HashSet<>(); + int cidrs = 0; + int ips = 0; + + for (int i = 0; ips <= 10240 && cidrs < numberOfMasks && i < 1000000; i++) { + String ip; + int prefix; + boolean mask; + do { + mask = ips > 0 && random().nextBoolean(); + ip = generateRandomIPv4(); + prefix = 24 + random().nextInt(8); // CIDR prefix for IPv4 + } while (mask && !dedupeCidrs.add(getFirstThreeOctets(ip))); + + bulkRequestBuilder.add( + client().prepareIndex(indexName).setSource(Map.of("addr", ip, "dummy_filter", randomSubsetOf(1, "1", "2", "3"))) + ); + + final String termToQuery; + if (mask) { + termToQuery = ip + "/" + prefix; + cidrs++; + } else { + termToQuery = ip; + ips++; + } + queryTermsSink.add(termToQuery); + } + int addMatches = 0; + for (int i = 0; i < atLeast(100); i++) { + final String ip; + ip = generateRandomIPv4(); + bulkRequestBuilder.add( + client().prepareIndex(indexName).setSource(Map.of("addr", ip, "dummy_filter", randomSubsetOf(1, "1", "2", "3"))) + ); + boolean match = false; + for (String termQ : queryTermsSink) { + boolean isCidr = termQ.contains("/"); + if ((isCidr && isIPInCIDR(ip, termQ)) || (!isCidr && termQ.equals(ip))) { + match = true; + break; + } + } + if (match) { + addMatches++; + } else { + break; // single mismatch is enough. + } + } + + bulkRequestBuilder.setRefreshPolicy(IMMEDIATE).get(); + return ips + cidrs + addMatches; + } + + public void testLessThanMaxClauses() throws IOException { + ArrayList toQuery = new ArrayList<>(); + String indexName = "small"; + int expectMatches = createIndex(indexName, IndexSearcher.getMaxClauseCount() - 1, toQuery); + + assertTermsHitCount(indexName, "addr", toQuery, expectMatches); + assertTermsHitCount(indexName, "addr.idx", toQuery, expectMatches); + assertTermsHitCount(indexName, "addr.dv", toQuery, expectMatches); + // passing dummy filter crushes on rewriting + try { + assertTermsHitCount(indexName, "addr.dv", toQuery, expectMatches, (boolBuilder) -> { + boolBuilder.filter(QueryBuilders.termsQuery("dummy_filter", "1", "2", "3")) + .filter(QueryBuilders.termsQuery("dummy_filter", "1", "2", "3", "4")) + .filter(QueryBuilders.termsQuery("dummy_filter", "1", "2", "3", "4", "5")); + }); + fail(); + } catch (SearchPhaseExecutionException ose) { + assertTrue("exceeding on query rewrite", ose.shardFailures()[0].getCause() instanceof IndexSearcher.TooManyNestedClauses); + } + } + + public void testExceedMaxClauses() throws IOException { + ArrayList toQuery = new ArrayList<>(); + String indexName = "larger"; + int expectMatches = createIndex(indexName, IndexSearcher.getMaxClauseCount() + (rarely() ? 0 : atLeast(10)) // TODO + often some + // more + , toQuery); + assertTermsHitCount(indexName, "addr", toQuery, expectMatches); + assertTermsHitCount(indexName, "addr.idx", toQuery, expectMatches); + try { // error from mapper/parser + assertTermsHitCount(indexName, "addr.dv", toQuery, expectMatches); + fail(); + } catch (SearchPhaseExecutionException ose) { + assertTrue("exceeding on query building", ose.shardFailures()[0].getCause().getCause() instanceof IndexSearcher.TooManyClauses); + } + } + + public static String getFirstThreeOctets(String ipAddress) { + // Split the IP address by the dot delimiter + String[] octets = ipAddress.split("\\."); + + // Take the first three octets + String[] firstThreeOctets = new String[3]; + System.arraycopy(octets, 0, firstThreeOctets, 0, 3); + + // Join the first three octets back together with dots + return String.join(".", firstThreeOctets); + } + + private void assertTermsHitCount(String indexName, String field, Collection toQuery, long expectedMatches) { + assertTermsHitCount(indexName, field, toQuery, expectedMatches, (bqb) -> {}); + } + + private void assertTermsHitCount( + String indexName, + String field, + Collection toQuery, + long expectedMatches, + Consumer addFilter + ) { + TermsQueryBuilder ipTerms = QueryBuilders.termsQuery(field, new ArrayList<>(toQuery)); + BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); + addFilter.accept(boolQueryBuilder); + SearchResponse result = client().prepareSearch(indexName).setQuery(boolQueryBuilder.must(ipTerms) + // .filter(QueryBuilders.termsQuery("dummy_filter", "a", "b")) + ).get(); + long hitsFound = Objects.requireNonNull(result.getHits().getTotalHits()).value; + MatcherAssert.assertThat(field, hitsFound, equalTo(expectedMatches)); + } + + // Converts an IP string (either IPv4 or IPv6) to a byte array + private static byte[] ipToBytes(String ip) { + InetAddress inetAddress = InetAddresses.forString(ip); + return inetAddress.getAddress(); + } + + // Checks if an IP is within a given CIDR (works for both IPv4 and IPv6) + private static boolean isIPInCIDR(String ip, String cidr) { + String[] cidrParts = cidr.split("/"); + String cidrIp = cidrParts[0]; + int prefixLength = Integer.parseInt(cidrParts[1]); + + byte[] ipBytes = ipToBytes(ip); + byte[] cidrIpBytes = ipToBytes(cidrIp); + + // Calculate how many full bytes and how many bits are in the mask + int fullBytes = prefixLength / 8; + int extraBits = prefixLength % 8; + + // Compare full bytes + for (int i = 0; i < fullBytes; i++) { + if (ipBytes[i] != cidrIpBytes[i]) { + return false; + } + } + + // Compare extra bits (if any) + if (extraBits > 0) { + int mask = 0xFF << (8 - extraBits); + return (ipBytes[fullBytes] & mask) == (cidrIpBytes[fullBytes] & mask); + } + + return true; + } + + // Generate a random IPv4 address + private String generateRandomIPv4() { + return String.join( + ".", + String.valueOf(random().nextInt(256)), + String.valueOf(random().nextInt(256)), + String.valueOf(random().nextInt(256)), + String.valueOf(random().nextInt(256)) + ); + } + + private XContentBuilder createMapping() throws IOException { + return XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("addr") + .field("type", "ip") + .startObject("fields") + .startObject("idx") + .field("type", "ip") + .field("doc_values", false) + .endObject() + .startObject("dv") + .field("type", "ip") + .field("index", false) + .endObject() + .endObject() + .endObject() + .startObject("dummy_filter") + .field("type", "keyword") + .endObject() + .endObject() + .endObject(); + } +}