diff --git a/docs/index.md b/docs/index.md index 669fa25e0..1832953be 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,6 +14,7 @@ A Flint index is ... - Partition: skip data scan by maintaining and filtering partitioned column value per file. - MinMax: skip data scan by maintaining lower and upper bound of the indexed column per file. - ValueSet: skip data scan by building a unique value set of the indexed column per file. + - BloomFilter: skip data scan by building a bloom filter of the indexed column per file. - Covering Index: create index for selected columns within the source dataset to improve query performance - Materialized View: enhance query performance by storing precomputed and aggregated data from the source dataset @@ -23,7 +24,8 @@ Please see the following example in which Index Building Logic and Query Rewrite |----------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | Partition | CREATE SKIPPING INDEX
ON alb_logs
(
  year PARTITION,
  month PARTITION,
  day PARTITION,
  hour PARTITION
) | INSERT INTO flint_alb_logs_skipping_index
SELECT
  FIRST(year) AS year,
  FIRST(month) AS month,
  FIRST(day) AS day,
  FIRST(hour) AS hour,
  input_file_name() AS file_path
FROM alb_logs
GROUP BY
  input_file_name() | SELECT *
FROM alb_logs
WHERE year = 2023 AND month = 4
=>
SELECT *
FROM alb_logs (input_files =
  SELECT file_path
  FROM flint_alb_logs_skipping_index
  WHERE year = 2023 AND month = 4
)
WHERE year = 2023 AND month = 4 | | ValueSet | CREATE SKIPPING INDEX
ON alb_logs
(
  elb_status_code VALUE_SET
) | INSERT INTO flint_alb_logs_skipping_index
SELECT
  COLLECT_SET(elb_status_code) AS elb_status_code,
  input_file_name() AS file_path
FROM alb_logs
GROUP BY
  input_file_name() | SELECT *
FROM alb_logs
WHERE elb_status_code = 404
=>
SELECT *
FROM alb_logs (input_files =
  SELECT file_path
  FROM flint_alb_logs_skipping_index
  WHERE ARRAY_CONTAINS(elb_status_code, 404)
)
WHERE elb_status_code = 404 | -| MinMax | CREATE SKIPPING INDEX
ON alb_logs
(
  request_processing_time MIN_MAX
) | INSERT INTO flint_alb_logs_skipping_index
SELECT
  MIN(request_processing_time) AS request_processing_time_min,
  MAX(request_processing_time) AS request_processing_time_max,
  input_file_name() AS file_path
FROM alb_logs
GROUP BY
  input_file_name() | SELECT *
FROM alb_logs
WHERE request_processing_time = 100
=>
SELECT *
FROM alb_logs (input_files =
SELECT file_path
  FROM flint_alb_logs_skipping_index
  WHERE request_processing_time_min <= 100
    AND 100 <= request_processing_time_max
)
WHERE request_processing_time = 100 +| MinMax | CREATE SKIPPING INDEX
ON alb_logs
(
  request_processing_time MIN_MAX
) | INSERT INTO flint_alb_logs_skipping_index
SELECT
  MIN(request_processing_time) AS request_processing_time_min,
  MAX(request_processing_time) AS request_processing_time_max,
  input_file_name() AS file_path
FROM alb_logs
GROUP BY
  input_file_name() | SELECT *
FROM alb_logs
WHERE request_processing_time = 100
=>
SELECT *
FROM alb_logs (input_files =
SELECT file_path
  FROM flint_alb_logs_skipping_index
  WHERE request_processing_time_min <= 100
    AND 100 <= request_processing_time_max
)
WHERE request_processing_time = 100 | +| BloomFilter | CREATE SKIPPING INDEX
ON alb_logs
(
  client_ip BLOOM_FILTER
) | INSERT INTO flint_alb_logs_skipping_index
SELECT
  BLOOM_FILTER_AGG(client_ip) AS client_ip,
  input_file_name() AS file_path
FROM alb_logs
GROUP BY
  input_file_name() | SELECT *
FROM alb_logs
WHERE client_ip = '127.0.0.1'
=>
SELECT *
FROM alb_logs (input_files =
  SELECT file_path
  FROM flint_alb_logs_skipping_index
  WHERE BLOOM_FILTER_MIGHT_CONTAIN(client_ip, '127.0.0.1') = true
)
WHERE client_ip = '127.0.0.1' | ### Flint Index Refresh @@ -73,6 +75,7 @@ For now, Flint Index doesn't define its own data type and uses OpenSearch field | **FlintDataType** | |-------------------| | boolean | +| binary | | long | | integer | | short | @@ -469,6 +472,7 @@ flint.skippingIndex() .addPartitions("year", "month", "day") .addValueSet("elb_status_code") .addMinMax("request_processing_time") + .addBloomFilter("client_ip") .create() flint.refreshIndex("flint_spark_catalog_default_alb_logs_skipping_index") diff --git a/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java index b69f0e4d0..cbeab0a62 100644 --- a/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java +++ b/flint-core/src/main/java/org/apache/spark/metrics/sink/CloudWatchSink.java @@ -18,6 +18,11 @@ import com.codahale.metrics.MetricFilter; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.ScheduledReporter; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Properties; import java.util.concurrent.TimeUnit; @@ -26,6 +31,8 @@ import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; import org.opensearch.flint.core.metrics.reporter.DimensionedName; import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; +import com.fasterxml.jackson.databind.ObjectMapper; + /** * Implementation of the Spark metrics {@link Sink} interface @@ -38,6 +45,7 @@ * @author kmccaw */ public class CloudWatchSink implements Sink { + private static final ObjectMapper objectMapper = new ObjectMapper(); private final ScheduledReporter reporter; @@ -198,12 +206,26 @@ public CloudWatchSink( metricFilter = MetricFilter.ALL; } + final Optional dimensionGroupsProperty = getProperty(properties, PropertyKeys.DIMENSION_GROUPS); + DimensionNameGroups dimensionNameGroups = null; + if (dimensionGroupsProperty.isPresent()) { + try { + dimensionNameGroups = objectMapper.readValue(dimensionGroupsProperty.get(), DimensionNameGroups.class); + } catch (IOException e) { + final String message = String.format( + "Unable to parse value (%s) for the \"%s\" CloudWatchSink metrics property.", + dimensionGroupsProperty.get(), + PropertyKeys.DIMENSION_GROUPS); + throw new InvalidMetricsPropertyException(message, e); + } + } + final AmazonCloudWatchAsync cloudWatchClient = AmazonCloudWatchAsyncClient.asyncBuilder() .withCredentials(awsCredentialsProvider) .withRegion(awsRegion) .build(); - this.reporter = DimensionedCloudWatchReporter.forRegistry(metricRegistry, cloudWatchClient, namespaceProperty.get()) + DimensionedCloudWatchReporter.Builder builder = DimensionedCloudWatchReporter.forRegistry(metricRegistry, cloudWatchClient, namespaceProperty.get()) .convertRatesTo(TimeUnit.SECONDS) .convertDurationsTo(TimeUnit.MILLISECONDS) .filter(metricFilter) @@ -220,8 +242,13 @@ public CloudWatchSink( .withStatisticSet() .withGlobalDimensions() .withShouldParseDimensionsFromName(shouldParseInlineDimensions) - .withShouldAppendDropwizardTypeDimension(shouldAppendDropwizardTypeDimension) - .build(); + .withShouldAppendDropwizardTypeDimension(shouldAppendDropwizardTypeDimension); + + if (dimensionNameGroups != null && dimensionNameGroups.getDimensionGroups() != null) { + builder = builder.withDimensionNameGroups(dimensionNameGroups); + } + + this.reporter = 
builder.withDimensionNameGroups(dimensionNameGroups).build(); } @Override @@ -262,6 +289,7 @@ private static class PropertyKeys { static final String SHOULD_PARSE_INLINE_DIMENSIONS = "shouldParseInlineDimensions"; static final String SHOULD_APPEND_DROPWIZARD_TYPE_DIMENSION = "shouldAppendDropwizardTypeDimension"; static final String METRIC_FILTER_REGEX = "regex"; + static final String DIMENSION_GROUPS = "dimensionGroups"; } /** @@ -272,4 +300,45 @@ private static class PropertyDefaults { static final TimeUnit POLLING_PERIOD_TIME_UNIT = TimeUnit.MINUTES; static final boolean SHOULD_PARSE_INLINE_DIMENSIONS = false; } + + /** + * Represents a container for grouping dimension names used in metrics reporting. + * This class allows for the organization and storage of dimension names into logical groups, + * facilitating the dynamic construction and retrieval of dimension information for metrics. + */ + public static class DimensionNameGroups { + // Holds the grouping of dimension names categorized under different keys. + private Map>> dimensionGroups = new HashMap<>(); + + /** + * Sets the mapping of dimension groups. Each key in the map represents a category or a type + * of dimension, and the value is a list of dimension name groups, where each group is itself + * a list of dimension names that are logically grouped together. + * + * @param dimensionGroups A map of dimension groups categorized by keys, where each key maps + * to a list of dimension name groups. + */ + public void setDimensionGroups(Map>> dimensionGroups) { + if (dimensionGroups == null) { + final String message = String.format( + "Undefined value for the \"%s\" CloudWatchSink metrics property.", + PropertyKeys.DIMENSION_GROUPS); + throw new InvalidMetricsPropertyException(message); + } + this.dimensionGroups = dimensionGroups; + } + + /** + * Retrieves the current mapping of dimension groups. The structure of the returned map is such + * that each key represents a specific category or type of dimension, and the corresponding value + * is a list of dimension name groups. Each group is a list of dimension names that are logically + * grouped together. + * + * @return A map representing the groups of dimension names categorized by keys. Each key maps + * to a list of lists, where each inner list is a group of related dimension names. + */ + public Map>> getDimensionGroups() { + return dimensionGroups; + } + } } diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilter.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilter.java new file mode 100644 index 000000000..60aba1d2a --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/BloomFilter.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter; + +import java.io.IOException; +import java.io.OutputStream; + +/** + * Bloom filter interface inspired by [[org.apache.spark.util.sketch.BloomFilter]] but adapts to + * Flint index use and remove unnecessary API. + */ +public interface BloomFilter { + + /** + * Bloom filter binary format version. + */ + enum Version { + V1(1); + + private final int versionNumber; + + Version(int versionNumber) { + this.versionNumber = versionNumber; + } + + public int getVersionNumber() { + return versionNumber; + } + } + + /** + * @return the number of bits in the underlying bit array. 
+ */ + long bitSize(); + + /** + * Put an item into this bloom filter. + * + * @param item Long value item to insert + * @return true if bits changed which means the item must be first time added to the bloom filter. + * Otherwise, it maybe the first time or not. + */ + boolean put(long item); + + /** + * Merge this bloom filter with another bloom filter. + * + * @param bloomFilter bloom filter to merge + * @return bloom filter after merged + */ + BloomFilter merge(BloomFilter bloomFilter); + + /** + * @param item Long value item to check + * @return true if the item may exist in this bloom filter. Otherwise, it is definitely not exist. + */ + boolean mightContain(long item); + + /** + * Serialize this bloom filter and write it to an output stream. + * + * @param out output stream to write + */ + void writeTo(OutputStream out) throws IOException; +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/BitArray.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/BitArray.java new file mode 100644 index 000000000..2bf36b360 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/BitArray.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.flint.core.field.bloomfilter.classic; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Arrays; + +/** + * Bit array. + */ +class BitArray { + private final long[] data; + private long bitCount; + + BitArray(long numBits) { + this(new long[numWords(numBits)]); + } + + BitArray(long[] data) { + this.data = data; + long bitCount = 0; + for (long word : data) { + bitCount += Long.bitCount(word); + } + this.bitCount = bitCount; + } + + /** + * @return array length in bits + */ + long bitSize() { + return (long) data.length * Long.SIZE; + } + + /** + * @param index bit index + * @return whether bits at the given index is set + */ + boolean get(long index) { + return (data[(int) (index >>> 6)] & (1L << index)) != 0; + } + + /** + * Set bits at the given index. + * + * @param index bit index + * @return bit changed or not + */ + boolean set(long index) { + if (!get(index)) { + data[(int) (index >>> 6)] |= (1L << index); + bitCount++; + return true; + } + return false; + } + + /** + * Put another array in this bit array. 
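Before the implementation details below, a minimal usage sketch of the new `BloomFilter` interface, assuming the `ClassicBloomFilter` implementation introduced later in this change (the class name `BloomFilterUsageSketch` is illustrative only). It exercises the no-false-negative guarantee and the V1 serialization round trip:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import org.opensearch.flint.core.field.bloomfilter.BloomFilter;
import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter;

public class BloomFilterUsageSketch {
  public static void main(String[] args) throws Exception {
    // Size the filter for ~1000 distinct items at a 1% false positive probability.
    BloomFilter filter = new ClassicBloomFilter(1000, 0.01);
    filter.put(42L);
    filter.put(7L);

    // No false negatives: every item that was put() must report "might contain".
    System.out.println(filter.mightContain(42L)); // true
    System.out.println(filter.mightContain(7L));  // true

    // Serialize to the V1 binary format and restore an equal filter.
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    filter.writeTo(out);
    BloomFilter restored =
        ClassicBloomFilter.readFrom(new ByteArrayInputStream(out.toByteArray()));
    System.out.println(filter.equals(restored)); // true
  }
}
```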
+ * + * @param array other bit array + */ + void putAll(BitArray array) { + assert data.length == array.data.length : "BitArrays must be of equal length when merging"; + long bitCount = 0; + for (int i = 0; i < data.length; i++) { + data[i] |= array.data[i]; + bitCount += Long.bitCount(data[i]); + } + this.bitCount = bitCount; + } + + /** + * Serialize and write out this bit array to the given output stream. + * + * @param out output stream + */ + void writeTo(DataOutputStream out) throws IOException { + out.writeInt(data.length); + for (long datum : data) { + out.writeLong(datum); + } + } + + /** + * Deserialize and read bit array from the given input stream. + * + * @param in input stream + * @return bit array + */ + static BitArray readFrom(DataInputStream in) throws IOException { + int numWords = in.readInt(); + long[] data = new long[numWords]; + for (int i = 0; i < numWords; i++) { + data[i] = in.readLong(); + } + return new BitArray(data); + } + + private static int numWords(long numBits) { + if (numBits <= 0) { + throw new IllegalArgumentException("numBits must be positive, but got " + numBits); + } + long numWords = (long) Math.ceil(numBits / 64.0); + if (numWords > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Can't allocate enough space for " + numBits + " bits"); + } + return (int) numWords; + } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (!(other instanceof BitArray)) return false; + BitArray that = (BitArray) other; + return Arrays.equals(data, that.data); + } + + @Override + public int hashCode() { + return Arrays.hashCode(data); + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java new file mode 100644 index 000000000..f6444fb09 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilter.java @@ -0,0 +1,172 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.flint.core.field.bloomfilter.classic; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.opensearch.flint.core.field.bloomfilter.BloomFilter; + +/** + * Classic bloom filter implementation inspired by [[org.apache.spark.util.sketch.BloomFilterImpl]] + * but only keep minimal functionality. 
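The `BitArray` word/bit addressing above relies on Java's long-shift semantics: the high bits of the index select a 64-bit word, and `1L << index` implicitly shifts by `index % 64`. A self-contained sketch of that indexing (not the package-private `BitArray` class itself):

```java
public class BitIndexingSketch {
  public static void main(String[] args) {
    long[] data = new long[2];           // 128 bits backed by two 64-bit words
    long index = 70;

    int word = (int) (index >>> 6);      // 70 / 64 = 1: the second word holds the bit
    data[word] |= (1L << index);         // 1L << 70 behaves as 1L << (70 % 64), i.e. bit 6

    boolean isSet = (data[word] & (1L << index)) != 0;
    System.out.println(word + " " + isSet); // prints "1 true"
  }
}
```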
Bloom filter is serialized in the following format: + *
+ * 1) Version number, always 1 (32 bit) + * 2) Number of hash functions (32 bit) + * 3) Total number of words of the underlying bit array (32 bit) + * 4) The words/longs (numWords * 64 bit) + */ +public class ClassicBloomFilter implements BloomFilter { + + /** + * Bit array + */ + private final BitArray bits; + + /** + * Number of hash function + */ + private final int numHashFunctions; + + public ClassicBloomFilter(int expectedNumItems, double fpp) { + long numBits = optimalNumOfBits(expectedNumItems, fpp); + this.bits = new BitArray(numBits); + this.numHashFunctions = optimalNumOfHashFunctions(expectedNumItems, numBits); + } + + ClassicBloomFilter(BitArray bits, int numHashFunctions) { + this.bits = bits; + this.numHashFunctions = numHashFunctions; + } + + @Override + public long bitSize() { + return bits.bitSize(); + } + + @Override + public boolean put(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + boolean bitsChanged = false; + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + bitsChanged |= bits.set(combinedHash % bitSize); + } + return bitsChanged; + } + + @Override + public BloomFilter merge(BloomFilter other) { + if (!(other instanceof ClassicBloomFilter)) { + throw new IllegalStateException("Cannot merge incompatible bloom filter of class" + + other.getClass().getName()); + } + this.bits.putAll(((ClassicBloomFilter) other).bits); + return this; + } + + @Override + public boolean mightContain(long item) { + int h1 = Murmur3_x86_32.hashLong(item, 0); + int h2 = Murmur3_x86_32.hashLong(item, h1); + + long bitSize = bits.bitSize(); + for (int i = 1; i <= numHashFunctions; i++) { + int combinedHash = h1 + (i * h2); + // Flip all the bits if it's negative (guaranteed positive number) + if (combinedHash < 0) { + combinedHash = ~combinedHash; + } + if (!bits.get(combinedHash % bitSize)) { + return false; + } + } + return true; + } + + @Override + public void writeTo(OutputStream out) throws IOException { + DataOutputStream dos = new DataOutputStream(out); + + dos.writeInt(Version.V1.getVersionNumber()); + dos.writeInt(numHashFunctions); + bits.writeTo(dos); + } + + /** + * Deserialize and read bloom filter from an input stream. + * + * @param in input stream + * @return bloom filter + */ + public static BloomFilter readFrom(InputStream in) throws IOException { + DataInputStream dis = new DataInputStream(in); + + int version = dis.readInt(); + if (version != Version.V1.getVersionNumber()) { + throw new IOException("Unexpected Bloom filter version number (" + version + ")"); + } + int numHashFunctions = dis.readInt(); + BitArray bits = BitArray.readFrom(dis); + return new ClassicBloomFilter(bits, numHashFunctions); + } + + private static int optimalNumOfHashFunctions(long n, long m) { + // (m / n) * log(2), but avoid truncation due to division! 
+ return Math.max(1, (int) Math.round((double) m / n * Math.log(2))); + } + + private static long optimalNumOfBits(long n, double p) { + return (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2))); + } + + @Override + public boolean equals(Object other) { + if (other == this) { + return true; + } + if (!(other instanceof ClassicBloomFilter)) { + return false; + } + ClassicBloomFilter that = (ClassicBloomFilter) other; + return this.numHashFunctions == that.numHashFunctions && this.bits.equals(that.bits); + } + + @Override + public int hashCode() { + return bits.hashCode() * 31 + numHashFunctions; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/Murmur3_x86_32.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/Murmur3_x86_32.java new file mode 100644 index 000000000..b76c3bd88 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/classic/Murmur3_x86_32.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file contains code from the Apache Spark project (original license below). + * It contains modifications, which are licensed as above: + */ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.flint.core.field.bloomfilter.classic; + +/** + * 32-bit Murmur3 hasher. This is based on Guava's Murmur3_32HashFunction. + */ +class Murmur3_x86_32 { + private static final int C1 = 0xcc9e2d51; + private static final int C2 = 0x1b873593; + + /** + * Calculate hash for the given input long. 
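Stepping back to the sizing helpers above, a worked example using the same expressions as `optimalNumOfBits` and `optimalNumOfHashFunctions`: for 100 expected items at a 1% false positive probability the constructor requests 958 bits (which `BitArray` rounds up to 15 longs, i.e. 960 bits) and 7 hash functions. The small program below is only a verification sketch of that arithmetic:

```java
public class BloomFilterSizingSketch {
  public static void main(String[] args) {
    int n = 100;      // expected number of items
    double p = 0.01;  // target false positive probability

    // m = -n * ln(p) / ln(2)^2, k = round(m/n * ln 2), as in ClassicBloomFilter
    long m = (long) (-n * Math.log(p) / (Math.log(2) * Math.log(2)));
    int k = Math.max(1, (int) Math.round((double) m / n * Math.log(2)));

    System.out.println(m); // 958 bits requested; BitArray allocates 15 longs = 960 bits
    System.out.println(k); // 7 hash functions
  }
}
```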
+ * + * @param input long value + * @param seed seed + * @return hash value + */ + static int hashLong(long input, int seed) { + int low = (int) input; + int high = (int) (input >>> 32); + + int k1 = mixK1(low); + int h1 = mixH1(seed, k1); + + k1 = mixK1(high); + h1 = mixH1(h1, k1); + + return fmix(h1, 8); + } + + private static int mixK1(int k1) { + k1 *= C1; + k1 = Integer.rotateLeft(k1, 15); + k1 *= C2; + return k1; + } + + private static int mixH1(int h1, int k1) { + h1 ^= k1; + h1 = Integer.rotateLeft(h1, 13); + h1 = h1 * 5 + 0xe6546b64; + return h1; + } + + // Finalization mix - force all bits of a hash block to avalanche + private static int fmix(int h1, int length) { + h1 ^= length; + h1 ^= h1 >>> 16; + h1 *= 0x85ebca6b; + h1 ^= h1 >>> 13; + h1 *= 0xc2b2ae35; + h1 ^= h1 >>> 16; + return h1; + } +} diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java new file mode 100644 index 000000000..ce7136507 --- /dev/null +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionUtils.java @@ -0,0 +1,95 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + +import java.util.Map; +import java.util.function.Function; +import org.apache.commons.lang.StringUtils; + +import com.amazonaws.services.cloudwatch.model.Dimension; + +/** + * Utility class for creating and managing CloudWatch dimensions for metrics reporting in Flint. + * It facilitates the construction of dimensions based on different system properties and environment + * variables, supporting the dynamic tagging of metrics with relevant information like job ID, + * application ID, and more. + */ +public class DimensionUtils { + private static final String DIMENSION_JOB_ID = "jobId"; + private static final String DIMENSION_APPLICATION_ID = "applicationId"; + private static final String DIMENSION_APPLICATION_NAME = "applicationName"; + private static final String DIMENSION_DOMAIN_ID = "domainId"; + private static final String DIMENSION_INSTANCE_ROLE = "instanceRole"; + private static final String UNKNOWN = "UNKNOWN"; + + // Maps dimension names to functions that generate Dimension objects based on specific logic or environment variables + private static final Map> dimensionBuilders = Map.of( + DIMENSION_INSTANCE_ROLE, DimensionUtils::getInstanceRoleDimension, + DIMENSION_JOB_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_JOB_ID", DIMENSION_JOB_ID), + DIMENSION_APPLICATION_ID, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", DIMENSION_APPLICATION_ID), + DIMENSION_APPLICATION_NAME, ignored -> getEnvironmentVariableDimension("SERVERLESS_EMR_APPLICATION_NAME", DIMENSION_APPLICATION_NAME), + DIMENSION_DOMAIN_ID, ignored -> getEnvironmentVariableDimension("FLINT_CLUSTER_NAME", DIMENSION_DOMAIN_ID) + ); + + /** + * Constructs a CloudWatch Dimension object based on the provided dimension name. If a specific + * builder exists for the dimension name, it is used; otherwise, a default dimension is constructed. + * + * @param dimensionName The name of the dimension to construct. + * @param parts Additional information that might be required by specific dimension builders. + * @return A CloudWatch Dimension object. 
+ */ + public static Dimension constructDimension(String dimensionName, String[] metricNameParts) { + if (!doesNameConsistsOfMetricNameSpace(metricNameParts)) { + throw new IllegalArgumentException("The provided metric name parts do not consist of a valid metric namespace."); + } + return dimensionBuilders.getOrDefault(dimensionName, ignored -> getDefaultDimension(dimensionName)) + .apply(metricNameParts); + } + + // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137 + // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName. + public static boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) { + return metricNameParts.length >= 3 + && (metricNameParts[1].equals("driver") || StringUtils.isNumeric(metricNameParts[1])); + } + + /** + * Generates a Dimension object representing the instance role (either executor or driver) based on the + * metric name parts provided. + * + * @param parts An array where the second element indicates the role by being numeric (executor) or not (driver). + * @return A Dimension object with the instance role. + */ + private static Dimension getInstanceRoleDimension(String[] parts) { + String value = StringUtils.isNumeric(parts[1]) ? "executor" : parts[1]; + return new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue(value); + } + + /** + * Constructs a Dimension object using a system environment variable. If the environment variable is not found, + * it uses a predefined "UNKNOWN" value. + * + * @param envVarName The name of the environment variable to use for the dimension's value. + * @param dimensionName The name of the dimension. + * @return A Dimension object populated with the appropriate name and value. + */ + private static Dimension getEnvironmentVariableDimension(String envVarName, String dimensionName) { + String value = System.getenv().getOrDefault(envVarName, UNKNOWN); + return new Dimension().withName(dimensionName).withValue(value); + } + + /** + * Provides a generic mechanism to construct a Dimension object with an environment variable value + * or a default value if the environment variable is not set. + * + * @param dimensionName The name of the dimension for which to retrieve the value. + * @return A Dimension object populated with the dimension name and its corresponding value. 
+ */ + private static Dimension getDefaultDimension(String dimensionName) { + return getEnvironmentVariableDimension(dimensionName, dimensionName); + } +} \ No newline at end of file diff --git a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java index a47fa70ce..e16eb0021 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporter.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Date; -import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -48,9 +47,12 @@ import java.util.stream.LongStream; import java.util.stream.Stream; import org.apache.commons.lang.StringUtils; +import org.apache.spark.metrics.sink.CloudWatchSink.DimensionNameGroups; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static org.opensearch.flint.core.metrics.reporter.DimensionUtils.constructDimension; + /** * Reports metrics to Amazon's CloudWatch periodically. *
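A short sketch of how the `DimensionUtils.constructDimension` helper above resolves dimensions from a Spark metric name of the form `<appId>.<executorId|driver>.<metricSource>...`. The metric name is borrowed from the tests; the `jobId` value comes from the `SERVERLESS_EMR_JOB_ID` environment variable and falls back to `UNKNOWN`:

```java
import com.amazonaws.services.cloudwatch.model.Dimension;
import org.opensearch.flint.core.metrics.reporter.DimensionUtils;

public class DimensionResolutionSketch {
  public static void main(String[] args) {
    String[] parts = "unknown.1.TestSource.shuffle-client.usedDirectMemory".split("\\.");

    // A numeric second part means the metric was emitted by an executor.
    Dimension role = DimensionUtils.constructDimension("instanceRole", parts);
    System.out.println(role.getName() + "=" + role.getValue()); // instanceRole=executor

    // "jobId" resolves from SERVERLESS_EMR_JOB_ID, or "UNKNOWN" when unset.
    Dimension jobId = DimensionUtils.constructDimension("jobId", parts);
    System.out.println(jobId.getName() + "=" + jobId.getValue());
  }
}
```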
@@ -84,16 +86,6 @@ public class DimensionedCloudWatchReporter extends ScheduledReporter { // Visible for testing public static final String DIMENSION_SNAPSHOT_STD_DEV = "snapshot-std-dev"; - public static final String DIMENSION_JOB_ID = "jobId"; - - public static final String DIMENSION_APPLICATION_ID = "applicationId"; - - public static final String DIMENSION_DOMAIN_ID = "domainId"; - - public static final String DIMENSION_INSTANCE_ROLE = "instanceRole"; - - public static final String UNKNOWN = "unknown"; - /** * Amazon CloudWatch rejects values that are either too small or too large. * Values must be in the range of 8.515920e-109 to 1.174271e+108 (Base 10) or 2e-360 to 2e360 (Base 2). @@ -103,6 +95,8 @@ public class DimensionedCloudWatchReporter extends ScheduledReporter { private static final double SMALLEST_SENDABLE_VALUE = 8.515920e-109; private static final double LARGEST_SENDABLE_VALUE = 1.174271e+108; + private static Map constructedDimensions; + /** * Each CloudWatch API request may contain at maximum 20 datums */ @@ -133,6 +127,7 @@ private DimensionedCloudWatchReporter(final Builder builder) { this.durationUnit = builder.cwDurationUnit; this.shouldParseDimensionsFromName = builder.withShouldParseDimensionsFromName; this.shouldAppendDropwizardTypeDimension = builder.withShouldAppendDropwizardTypeDimension; + this.constructedDimensions = new ConcurrentHashMap<>(); this.filter = MetricFilter.ALL; } @@ -349,34 +344,89 @@ private void stageMetricDatum(final boolean metricConfigured, // Only submit metrics that show some data, so let's save some money if (metricConfigured && (builder.withZeroValuesSubmission || metricValue > 0)) { final DimensionedName dimensionedName = DimensionedName.decode(metricName); + // Add global dimensions for all metrics final Set dimensions = new LinkedHashSet<>(builder.globalDimensions); - MetricInfo metricInfo = getMetricInfo(dimensionedName); - dimensions.addAll(metricInfo.getDimensions()); if (shouldAppendDropwizardTypeDimension) { dimensions.add(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(dimensionValue)); } - metricData.add(new MetricDatum() - .withTimestamp(new Date(builder.clock.getTime())) - .withValue(cleanMetricValue(metricValue)) - .withMetricName(metricInfo.getMetricName()) - .withDimensions(dimensions) - .withUnit(standardUnit)); + MetricInfo metricInfo = getMetricInfo(dimensionedName, dimensions); + for (Set dimensionSet : metricInfo.getDimensionSets()) { + MetricDatum datum = new MetricDatum() + .withTimestamp(new Date(builder.clock.getTime())) + .withValue(cleanMetricValue(metricValue)) + .withMetricName(metricInfo.getMetricName()) + .withDimensions(dimensionSet) + .withUnit(standardUnit); + metricData.add(datum); + } } } - public MetricInfo getMetricInfo(DimensionedName dimensionedName) { + /** + * Constructs a {@link MetricInfo} object based on the provided {@link DimensionedName} and a set of additional dimensions. + * This method processes the metric name contained within {@code dimensionedName} to potentially modify it based on naming conventions + * and extracts or generates additional dimension sets for detailed metrics reporting. + *
+ * If no specific naming convention is detected, the original set of dimensions is used as is. The method finally encapsulates the metric name + * and the collection of dimension sets in a {@link MetricInfo} object and returns it. + * + * @param dimensionedName An instance of {@link DimensionedName} containing the original metric name and any directly associated dimensions. + * @param dimensions A set of {@link Dimension} objects provided externally that should be associated with the metric. + * @return A {@link MetricInfo} object containing the processed metric name and a list of dimension sets for metrics reporting. + */ + private MetricInfo getMetricInfo(DimensionedName dimensionedName, Set dimensions) { + // Add dimensions from dimensionedName + dimensions.addAll(dimensionedName.getDimensions()); + String metricName = dimensionedName.getName(); String[] parts = metricName.split("\\."); - Set dimensions = new HashSet<>(); - if (doesNameConsistsOfMetricNameSpace(parts)) { + List> dimensionSets = new ArrayList<>(); + if (DimensionUtils.doesNameConsistsOfMetricNameSpace(parts)) { metricName = constructMetricName(parts); - addInstanceRoleDimension(dimensions, parts); + // Get dimension sets corresponding to a specific metric source + constructDimensionSets(dimensionSets, parts); + // Add dimensions constructed above into each of the dimensionSets + for (Set dimensionSet : dimensionSets) { + // Create a copy of each set and add the additional dimensions + dimensionSet.addAll(dimensions); + } + } + + if (dimensionSets.isEmpty()) { + dimensionSets.add(dimensions); + } + return new MetricInfo(metricName, dimensionSets); + } + + /** + * Populates a list of dimension sets based on the metric source name extracted from the metric's parts + * and predefined dimension groupings. This method aims to create detailed and structured dimension + * sets for metrics, enhancing the granularity and relevance of metric reporting. + * + * If no predefined dimension groups exist for the metric source, or if the dimension name groups are + * not initialized, the method exits without modifying the dimension sets list. + * + * @param dimensionSets A list to be populated with sets of {@link Dimension} objects, each representing + * a group of dimensions relevant to the metric's source. + * @param parts An array of strings derived from splitting the metric's name, used to extract information + * like the metric source name and to construct dimensions based on naming conventions. 
+ */ + private void constructDimensionSets(List> dimensionSets, String[] parts) { + String metricSourceName = parts[2]; + if (builder.dimensionNameGroups == null || builder.dimensionNameGroups.getDimensionGroups() == null || !builder.dimensionNameGroups.getDimensionGroups().containsKey(metricSourceName)) { + return; + } + + for (List dimensionNames: builder.dimensionNameGroups.getDimensionGroups().get(metricSourceName)) { + Set dimensions = new LinkedHashSet<>(); + for (String dimensionName: dimensionNames) { + constructedDimensions.putIfAbsent(dimensionName, constructDimension(dimensionName, parts)); + dimensions.add(constructedDimensions.get(dimensionName)); + } + dimensionSets.add(dimensions); } - addDefaultDimensionsForSparkJobMetrics(dimensions); - dimensions.addAll(dimensionedName.getDimensions()); - return new MetricInfo(metricName, dimensions); } /** @@ -393,31 +443,6 @@ private String constructMetricName(String[] metricNameParts) { return Stream.of(metricNameParts).skip(partsToSkip).collect(Collectors.joining(".")); } - // These dimensions are for all metrics - // TODO: Remove EMR-S specific env vars https://github.com/opensearch-project/opensearch-spark/issues/231 - private static void addDefaultDimensionsForSparkJobMetrics(Set dimensions) { - final String jobId = System.getenv().getOrDefault("SERVERLESS_EMR_JOB_ID", UNKNOWN); - final String applicationId = System.getenv().getOrDefault("SERVERLESS_EMR_VIRTUAL_CLUSTER_ID", UNKNOWN); - dimensions.add(new Dimension().withName(DIMENSION_JOB_ID).withValue(jobId)); - dimensions.add(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(applicationId)); - } - - private static void addInstanceRoleDimension(Set dimensions, String[] parts) { - Dimension instanceRoleDimension; - if (StringUtils.isNumeric(parts[1])) { - instanceRoleDimension = new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue("executor"); - } else { - instanceRoleDimension = new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue(parts[1]); - } - dimensions.add(instanceRoleDimension); - } - // This tries to replicate the logic here: https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala#L137 - // Since we don't have access to Spark Configuration here: we are relying on the presence of executorId as part of the metricName. 
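To make the grouping concrete, a hedged sketch of the `dimensionGroups` sink property that `constructDimensionSets` consumes. The JSON shape matches the `DimensionNameGroups` class (a top-level `dimensionGroups` object mapping a metric source name to lists of dimension-name groups); the source name `TestSource` and the dimension names mirror the tests and are illustrative only. Each inner list becomes one dimension set, and each set produces its own `MetricDatum`:

```java
import java.util.Properties;

public class DimensionGroupsConfigSketch {
  public static void main(String[] args) {
    // Values that would normally live in the Spark metrics properties for the CloudWatch sink.
    Properties props = new Properties();
    props.setProperty("namespace", "FlintMetrics");
    props.setProperty("dimensionGroups",
        "{\"dimensionGroups\": {"
            + "\"TestSource\": [[\"appName\", \"instanceRole\"], [\"appName\"]]"
            + "}}");

    System.out.println(props.getProperty("dimensionGroups"));
  }
}
```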
- private boolean doesNameConsistsOfMetricNameSpace(String[] metricNameParts) { - return metricNameParts.length >= 3 - && (metricNameParts[1].equals("driver") || StringUtils.isNumeric(metricNameParts[1])); - } - private void stageMetricDatumWithConvertedSnapshot(final boolean metricConfigured, final String metricName, final Snapshot snapshot, @@ -545,19 +570,19 @@ public String getDesc() { public static class MetricInfo { private String metricName; - private Set dimensions; + private List> dimensionSets; - public MetricInfo(String metricName, Set dimensions) { + public MetricInfo(String metricName, List> dimensionSets) { this.metricName = metricName; - this.dimensions = dimensions; + this.dimensionSets = dimensionSets; } public String getMetricName() { return metricName; } - public Set getDimensions() { - return dimensions; + public List> getDimensionSets() { + return dimensionSets; } } @@ -587,6 +612,7 @@ public static class Builder { private StandardUnit cwRateUnit; private StandardUnit cwDurationUnit; private Set globalDimensions; + private DimensionNameGroups dimensionNameGroups; private final Clock clock; private Builder( @@ -787,6 +813,11 @@ public Builder withShouldAppendDropwizardTypeDimension(final boolean value) { return this; } + public Builder withDimensionNameGroups(final DimensionNameGroups dimensionNameGroups) { + this.dimensionNameGroups = dimensionNameGroups; + return this; + } + /** * Does not actually POST to CloudWatch, logs the {@link PutMetricDataRequest putMetricDataRequest} instead. * {@code false} by default. diff --git a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java similarity index 62% rename from flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java rename to flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java index 6f87276a8..db2948858 100644 --- a/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTests.java +++ b/flint-core/src/test/java/apache/spark/metrics/sink/CloudWatchSinkTest.java @@ -16,7 +16,10 @@ import java.util.Properties; import org.opensearch.flint.core.metrics.reporter.InvalidMetricsPropertyException; -class CloudWatchSinkTests { +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +class CloudWatchSinkTest { private final MetricRegistry metricRegistry = Mockito.mock(MetricRegistry.class); private final SecurityManager securityManager = Mockito.mock(SecurityManager.class); @@ -71,6 +74,44 @@ void should_throwException_when_pollingTimeUnitPropertyIsInvalid() { Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); } + @Test + void should_throwException_when_DimensionGroupsPropertyIsInvalid() { + final Properties properties = getDefaultValidProperties(); + String jsonString = "{\"dimensionGroups\":[{\"MetricSource1\":{}}, [\"Dimension1\",\"Dimension2\",\"Dimension3\"]]}]}"; + properties.setProperty("dimensionGroups", jsonString); + final Executable executable = () -> { + final CloudWatchSink cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + }; + InvalidMetricsPropertyException exception = Assertions.assertThrows(InvalidMetricsPropertyException.class, executable); + StringBuilder expectedMessageBuilder = new StringBuilder(); + expectedMessageBuilder.append("Unable to parse value (") + .append(jsonString) + .append(") for the \"dimensionGroups\" CloudWatchSink metrics property."); + 
Assertions.assertEquals(expectedMessageBuilder.toString(), exception.getMessage()); + } + + @Test + public void should_CreateCloudWatchSink_When_dimensionGroupsPropertyIsValid() { + final Properties properties = getDefaultValidProperties(); + String jsonString = "{" + + "\"dimensionGroups\": {" + + "\"MetricSource1\": [[\"DimensionA1\", \"DimensionA2\"], [\"DimensionA1\"]]," + + "\"MetricSource2\": [[\"DimensionB1\"], [\"DimensionB2\", \"DimensionB3\", \"DimensionB4\"]]," + + "\"MetricSource3\": [[\"DimensionC1\", \"DimensionC2\", \"DimensionC3\"], [\"DimensionC4\"], [\"DimensionC5\", \"DimensionC6\"]]" + + "}" + + "}"; + properties.setProperty("dimensionGroups", jsonString); + + CloudWatchSink cloudWatchSink = null; + try { + cloudWatchSink = new CloudWatchSink(properties, metricRegistry, securityManager); + } catch (Exception e) { + fail("Should not have thrown any exception, but threw: " + e.getMessage()); + } + + assertNotNull("CloudWatchSink should be created", cloudWatchSink); + } + private Properties getDefaultValidProperties() { final Properties properties = new Properties(); properties.setProperty("namespace", "namespaceValue"); diff --git a/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilterTest.java new file mode 100644 index 000000000..39ca8e98d --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/classic/ClassicBloomFilterTest.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.field.bloomfilter.classic; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import org.junit.Test; +import org.opensearch.flint.core.field.bloomfilter.BloomFilter; + +public class ClassicBloomFilterTest { + + private final ClassicBloomFilter bloomFilter = new ClassicBloomFilter(100, 0.01); + + private static final double ACCEPTABLE_FALSE_POSITIVE_RATE = 0.2; + + @Test + public void shouldReturnNoFalseNegative() { + bloomFilter.put(123L); + bloomFilter.put(456L); + bloomFilter.put(789L); + + // For items added, expect no false negative + assertTrue(bloomFilter.mightContain(123L)); + assertTrue(bloomFilter.mightContain(456L)); + assertTrue(bloomFilter.mightContain(789L)); + } + + @Test + public void shouldReturnFalsePositiveLessThanConfigured() { + bloomFilter.put(123L); + bloomFilter.put(456L); + bloomFilter.put(789L); + + // For items not added, expect false positives much lower than configure 1% + int numElements = 1000; + int falsePositiveCount = 0; + for (int i = 0; i < numElements; i++) { + long element = 1000L + i; + if (bloomFilter.mightContain(element)) { + falsePositiveCount++; + } + } + + double actualFalsePositiveRate = (double) falsePositiveCount / numElements; + assertTrue(actualFalsePositiveRate <= ACCEPTABLE_FALSE_POSITIVE_RATE, + "Actual false positive rate is higher than expected"); + } + + @Test + public void shouldBeTheSameAfterWriteToAndReadFrom() throws IOException { + bloomFilter.put(123L); + bloomFilter.put(456L); + bloomFilter.put(789L); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + bloomFilter.writeTo(out); + InputStream in = new ByteArrayInputStream(out.toByteArray()); + 
BloomFilter newBloomFilter = ClassicBloomFilter.readFrom(in); + assertEquals(bloomFilter, newBloomFilter); + } +} \ No newline at end of file diff --git a/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java new file mode 100644 index 000000000..7fab8c346 --- /dev/null +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionUtilsTest.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.core.metrics.reporter; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.amazonaws.services.cloudwatch.model.Dimension; +import org.junit.jupiter.api.function.Executable; + +import java.lang.reflect.Field; +import java.util.Map; + +public class DimensionUtilsTest { + private static final String[] parts = {"someMetric", "123", "dummySource"}; + + @Test + void testConstructDimensionThrowsIllegalArgumentException() { + String dimensionName = "InvalidDimension"; + String[] metricNameParts = {}; + + final Executable executable = () -> { + DimensionUtils.constructDimension(dimensionName, metricNameParts); + }; + IllegalArgumentException exception = Assertions.assertThrows(IllegalArgumentException.class, executable); + Assertions.assertEquals("The provided metric name parts do not consist of a valid metric namespace.", exception.getMessage()); + } + @Test + public void testGetInstanceRoleDimensionWithExecutor() { + Dimension result = DimensionUtils.constructDimension("instanceRole", parts); + assertEquals("instanceRole", result.getName()); + assertEquals("executor", result.getValue()); + } + + @Test + public void testGetInstanceRoleDimensionWithRoleName() { + String[] parts = {"someMetric", "driver", "dummySource"}; + Dimension result = DimensionUtils.constructDimension("instanceRole", parts); + assertEquals("instanceRole", result.getName()); + assertEquals("driver", result.getValue()); + } + + @Test + public void testGetDefaultDimensionWithUnknown() { + Dimension result = DimensionUtils.constructDimension("nonExistentDimension", parts); + assertEquals("nonExistentDimension", result.getName()); + assertEquals("UNKNOWN", result.getValue()); + } + + @Test + public void testGetDimensionsFromSystemEnv() throws NoSuchFieldException, IllegalAccessException { + Class classOfMap = System.getenv().getClass(); + Field field = classOfMap.getDeclaredField("m"); + field.setAccessible(true); + Map writeableEnvironmentVariables = (Map)field.get(System.getenv()); + writeableEnvironmentVariables.put("TEST_VAR", "dummy1"); + writeableEnvironmentVariables.put("SERVERLESS_EMR_JOB_ID", "dummy2"); + Dimension result1 = DimensionUtils.constructDimension("TEST_VAR", parts); + assertEquals("TEST_VAR", result1.getName()); + assertEquals("dummy1", result1.getValue()); + Dimension result2 = DimensionUtils.constructDimension("jobId", parts); + assertEquals("jobId", result2.getName()); + assertEquals("dummy2", result2.getValue()); + } +} diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java similarity index 92% rename from flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java 
rename to flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java index 4774bcc0b..db58993ef 100644 --- a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedCloudWatchReporterTest.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package opensearch.flint.core.metrics.reporter; +package org.opensearch.flint.core.metrics.reporter; import com.amazonaws.services.cloudwatch.AmazonCloudWatchAsyncClient; import com.amazonaws.services.cloudwatch.model.Dimension; @@ -16,8 +16,15 @@ import com.codahale.metrics.Histogram; import com.codahale.metrics.MetricRegistry; import com.codahale.metrics.SlidingWindowReservoir; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; import java.util.Set; + +import org.apache.spark.metrics.sink.CloudWatchSink; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -36,8 +43,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; -import org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter; -import org.opensearch.flint.core.metrics.reporter.DimensionedName; import static com.amazonaws.services.cloudwatch.model.StandardUnit.Count; import static com.amazonaws.services.cloudwatch.model.StandardUnit.Microseconds; @@ -49,16 +54,12 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_APPLICATION_ID; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_COUNT; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_GAUGE; -import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_INSTANCE_ROLE; -import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_JOB_ID; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_NAME_TYPE; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_MEAN; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_STD_DEV; import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.DIMENSION_SNAPSHOT_SUMMARY; -import static org.opensearch.flint.core.metrics.reporter.DimensionedCloudWatchReporter.UNKNOWN; @ExtendWith(MockitoExtension.class) @MockitoSettings(strictness = Strictness.LENIENT) @@ -110,9 +111,8 @@ public void shouldReportWithoutGlobalDimensionsWhenGlobalDimensionsNotConfigured final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); - assertThat(dimensions).hasSize(3); + assertThat(dimensions).hasSize(1); assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); - assertDefaultDimensionsWithUnknownValue(dimensions); } @@ -124,7 +124,6 @@ public void reportedCounterShouldContainExpectedDimension() throws Exception { final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); assertThat(dimensions).contains(new 
Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_COUNT)); - assertDefaultDimensionsWithUnknownValue(dimensions); } @Test @@ -135,7 +134,6 @@ public void reportedGaugeShouldContainExpectedDimension() throws Exception { final List dimensions = firstMetricDatumDimensionsFromCapturedRequest(); assertThat(dimensions).contains(new Dimension().withName(DIMENSION_NAME_TYPE).withValue(DIMENSION_GAUGE)); - assertDefaultDimensionsWithUnknownValue(dimensions); } @Test @@ -483,7 +481,6 @@ public void shouldReportExpectedGlobalAndCustomDimensions() throws Exception { assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1")); assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2")); - assertDefaultDimensionsWithUnknownValue(dimensions); } @Test @@ -495,16 +492,14 @@ public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceDriverMetric() .build().encode()).inc(); reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report(); final DimensionedCloudWatchReporter.MetricInfo metricInfo = firstMetricDatumInfoFromCapturedRequest(); - Set dimensions = metricInfo.getDimensions(); + List> dimensionSets = metricInfo.getDimensionSets(); + Set dimensions = dimensionSets.get(0); assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1")); assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2")); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN)); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN)); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue("driver")); assertThat(metricInfo.getMetricName()).isEqualTo("LiveListenerBus.listenerProcessingTime.org.apache.spark.HeartbeatReceiver"); } - @Test + @Test public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceExecutorMetric() throws Exception { //setting jobId as unknown to invoke name parsing. 
metricRegistry.counter(DimensionedName.withName("unknown.1.NettyBlockTransfer.shuffle-client.usedDirectMemory") @@ -514,23 +509,44 @@ public void shouldParseDimensionedNamePrefixedWithMetricNameSpaceExecutorMetric( reporterBuilder.withGlobalDimensions("Region=us-west-2").build().report(); final DimensionedCloudWatchReporter.MetricInfo metricInfo = firstMetricDatumInfoFromCapturedRequest(); - Set dimensions = metricInfo.getDimensions(); + Set dimensions = metricInfo.getDimensionSets().get(0); assertThat(dimensions).contains(new Dimension().withName("Region").withValue("us-west-2")); assertThat(dimensions).contains(new Dimension().withName("key1").withValue("value1")); assertThat(dimensions).contains(new Dimension().withName("key2").withValue("value2")); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_INSTANCE_ROLE).withValue( "executor")); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN)); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN)); assertThat(metricInfo.getMetricName()).isEqualTo("NettyBlockTransfer.shuffle-client.usedDirectMemory"); } + @Test + public void shouldConsumeMultipleMetricDatumWithDimensionGroups() throws Exception { + // Setup + String metricSourceName = "TestSource"; + Map>> dimensionGroups = new HashMap<>(); + dimensionGroups.put(metricSourceName, Arrays.asList( + Arrays.asList("appName", "instanceRole"), + Arrays.asList("appName") + )); + + metricRegistry.counter(DimensionedName.withName("unknown.1.TestSource.shuffle-client.usedDirectMemory") + .build().encode()).inc(); + + CloudWatchSink.DimensionNameGroups dimensionNameGroups = new CloudWatchSink.DimensionNameGroups(); + dimensionNameGroups.setDimensionGroups(dimensionGroups); + reporterBuilder.withDimensionNameGroups(dimensionNameGroups).build().report(); + final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); + final List metricDatums = putMetricDataRequest.getMetricData(); + assertThat(metricDatums).hasSize(2); - private void assertDefaultDimensionsWithUnknownValue(List dimensions) { - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_JOB_ID).withValue(UNKNOWN)); - assertThat(dimensions).contains(new Dimension().withName(DIMENSION_APPLICATION_ID).withValue(UNKNOWN)); - } + MetricDatum metricDatum1 = metricDatums.get(0); + Set dimensions1 = new HashSet(metricDatum1.getDimensions()); + assertThat(dimensions1).contains(new Dimension().withName("appName").withValue("UNKNOWN")); + assertThat(dimensions1).contains(new Dimension().withName("instanceRole").withValue("executor")); + MetricDatum metricDatum2 = metricDatums.get(1); + Set dimensions2 = new HashSet(metricDatum2.getDimensions()); + assertThat(dimensions2).contains(new Dimension().withName("appName").withValue("UNKNOWN")); + assertThat(dimensions2).doesNotContain(new Dimension().withName("instanceRole").withValue("executor")); + } private MetricDatum metricDatumByDimensionFromCapturedRequest(final String dimensionValue) { final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); @@ -564,7 +580,10 @@ private List firstMetricDatumDimensionsFromCapturedRequest() { private DimensionedCloudWatchReporter.MetricInfo firstMetricDatumInfoFromCapturedRequest() { final PutMetricDataRequest putMetricDataRequest = metricDataRequestCaptor.getValue(); final MetricDatum metricDatum = putMetricDataRequest.getMetricData().get(0); - return new 
-        return new DimensionedCloudWatchReporter.MetricInfo(metricDatum.getMetricName(), new HashSet<>(metricDatum.getDimensions()));
+        Set<Dimension> dimensions = new HashSet<>(metricDatum.getDimensions());
+        List<Set<Dimension>> dimensionSet = new ArrayList<>();
+        dimensionSet.add(dimensions);
+        return new DimensionedCloudWatchReporter.MetricInfo(metricDatum.getMetricName(), dimensionSet);
     }
 
     private List<Dimension> allDimensionsFromCapturedRequest() {
diff --git a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
similarity index 97%
rename from flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
rename to flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
index d6145545d..6bc6a9c2d 100644
--- a/flint-core/src/test/java/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
+++ b/flint-core/src/test/java/org/opensearch/flint/core/metrics/reporter/DimensionedNameTest.java
@@ -1,4 +1,4 @@
-package opensearch.flint.core.metrics.reporter;
+package org.opensearch.flint.core.metrics.reporter;
 
 import static org.hamcrest.CoreMatchers.hasItems;
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
index 6cd5b3352..7a783a610 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkIndexFactory.scala
@@ -17,7 +17,8 @@ import org.opensearch.flint.spark.mv.FlintSparkMaterializedView.MV_INDEX_TYPE
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
-import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, MIN_MAX, PARTITION, VALUE_SET}
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy
 import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
 import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
 import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
@@ -59,6 +60,11 @@ object FlintSparkIndexFactory {
           params = parameters)
       case MIN_MAX =>
         MinMaxSkippingStrategy(columnName = columnName, columnType = columnType)
+      case BLOOM_FILTER =>
+        BloomFilterSkippingStrategy(
+          columnName = columnName,
+          columnType = columnType,
+          params = parameters)
       case other =>
         throw new IllegalStateException(s"Unknown skipping strategy: $other")
     }
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
index ae6518bf0..c27a7f7e2 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingIndex.scala
@@ -12,6 +12,7 @@ import org.opensearch.flint.spark._
 import org.opensearch.flint.spark.FlintSparkIndex._
 import org.opensearch.flint.spark.FlintSparkIndexOptions.empty
 import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy
 import org.opensearch.flint.spark.skipping.minmax.MinMaxSkippingStrategy
 import org.opensearch.flint.spark.skipping.partition.PartitionSkippingStrategy
 import org.opensearch.flint.spark.skipping.valueset.ValueSetSkippingStrategy
@@ -188,6 +189,25 @@ object FlintSparkSkippingIndex {
       this
     }
 
+    /**
+     * Add bloom filter skipping index column.
+     *
+     * @param colName
+     *   indexed column name
+     * @param params
+     *   bloom filter parameters
+     * @return
+     *   index builder
+     */
+    def addBloomFilter(colName: String, params: Map[String, String] = Map.empty): Builder = {
+      val col = findColumn(colName)
+      indexedColumns = indexedColumns :+ BloomFilterSkippingStrategy(
+        columnName = col.name,
+        columnType = col.dataType,
+        params = params)
+      this
+    }
+
     override def buildIndex(): FlintSparkIndex =
       new FlintSparkSkippingIndex(tableName, indexedColumns, indexOptions)
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
index 06b6daa13..de2ea772d 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingStrategy.scala
@@ -71,7 +71,7 @@ object FlintSparkSkippingStrategy {
     type SkippingKind = Value
 
     // Use Value[s]Set because ValueSet already exists in Enumeration
-    val PARTITION, VALUE_SET, MIN_MAX = Value
+    val PARTITION, VALUE_SET, MIN_MAX, BLOOM_FILTER = Value
   }
 
   /** json4s doesn't serialize Enum by default */
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala
new file mode 100644
index 000000000..b40554335
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterAgg.scala
@@ -0,0 +1,105 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping.bloomfilter
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+
+import org.opensearch.flint.core.field.bloomfilter.BloomFilter
+import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
+import org.apache.spark.sql.types.{BinaryType, DataType}
+
+/**
+ * An aggregate function that builds a bloom filter and serializes it to binary as the result.
+ * This implementation is a customized version inspired by Spark's built-in
+ * [[org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate]].
+ *
+ * The reasons for not reusing Spark's implementation include: it only accepts the expected
+ * number of bits, it is coupled with its own BloomFilterImpl, and most importantly it cannot be
+ * extended due to Scala case class restrictions.
+ *
+ * @param child
+ *   child expression that generates Long values for creating a bloom filter
+ * @param expectedNumItems
+ *   expected maximum number of unique items
+ * @param fpp
+ *   false positive probability
+ */
+case class BloomFilterAgg(
+    child: Expression,
+    expectedNumItems: Int,
+    fpp: Double,
+    override val mutableAggBufferOffset: Int,
+    override val inputAggBufferOffset: Int)
+    extends TypedImperativeAggregate[BloomFilter] {
+
+  def this(child: Expression, expectedNumItems: Int, fpp: Double) = {
+    this(child, expectedNumItems, fpp, 0, 0)
+  }
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = BinaryType
+
+  override def children: Seq[Expression] = Seq(child)
+
+  override def createAggregationBuffer(): BloomFilter = {
+    new ClassicBloomFilter(expectedNumItems, fpp)
+  }
+
+  override def update(buffer: BloomFilter, inputRow: InternalRow): BloomFilter = {
+    val value = child.eval(inputRow)
+    if (value == null) { // Ignore null values
+      return buffer
+    }
+    buffer.put(value.asInstanceOf[Long])
+    buffer
+  }
+
+  override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = {
+    buffer.merge(input)
+    buffer
+  }
+
+  override def eval(buffer: BloomFilter): Any = {
+    if (buffer.bitSize() == 0) {
+      // There's no set bit in the bloom filter, hence no non-null value has been processed.
+      return null
+    }
+    serialize(buffer)
+  }
+
+  override def serialize(buffer: BloomFilter): Array[Byte] = {
+    // Preallocate space. BloomFilter.writeTo() writes 2 integers (version number and
+    // num hash functions) first, hence +8
+    val size = (buffer.bitSize() / 8) + 8
+    require(size <= Integer.MAX_VALUE, s"actual number of bits is too large $size")
+    val out = new ByteArrayOutputStream(size.intValue())
+    buffer.writeTo(out)
+    out.close()
+    out.toByteArray
+  }
+
+  override def deserialize(bytes: Array[Byte]): BloomFilter = {
+    val in = new ByteArrayInputStream(bytes)
+    val bloomFilter = ClassicBloomFilter.readFrom(in)
+    in.close()
+    bloomFilter
+  }
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression =
+    copy(child = newChildren.head)
+
+  override def withNewMutableAggBufferOffset(newOffset: Int): ImperativeAggregate =
+    copy(mutableAggBufferOffset = newOffset)
+
+  override def withNewInputAggBufferOffset(newOffset: Int): ImperativeAggregate =
+    copy(inputAggBufferOffset = newOffset)
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala
new file mode 100644
index 000000000..73b03ef0f
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategy.scala
@@ -0,0 +1,70 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping.bloomfilter
+
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy
+import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{BLOOM_FILTER, SkippingKind}
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.functions.{col, xxhash64}
+
+/**
+ * Skipping strategy based on the bloom filter approximate data structure.
+ */
+case class BloomFilterSkippingStrategy(
+    override val kind: SkippingKind = BLOOM_FILTER,
+    override val columnName: String,
+    override val columnType: String,
+    params: Map[String, String] = Map.empty)
+    extends FlintSparkSkippingStrategy {
+
+  override val parameters: Map[String, String] = {
+    Map(
+      CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> expectedNumItems.toString,
+      CLASSIC_BLOOM_FILTER_FPP_KEY -> fpp.toString)
+  }
+
+  override def outputSchema(): Map[String, String] = Map(columnName -> "binary")
+
+  override def getAggregators: Seq[Expression] = {
+    Seq(
+      new BloomFilterAgg(xxhash64(col(columnName)).expr, expectedNumItems, fpp)
+        .toAggregateExpression()
+    ) // TODO: use xxhash64() for now
+  }
+
+  override def rewritePredicate(predicate: Expression): Option[Expression] = None
+
+  private def expectedNumItems: Int = {
+    params
+      .get(CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY)
+      .map(_.toInt)
+      .getOrElse(DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS)
+  }
+
+  private def fpp: Double = {
+    params
+      .get(CLASSIC_BLOOM_FILTER_FPP_KEY)
+      .map(_.toDouble)
+      .getOrElse(DEFAULT_CLASSIC_BLOOM_FILTER_FPP)
+  }
+}
+
+object BloomFilterSkippingStrategy {
+
+  /**
+   * Expected number of unique items key and default value.
+   */
+  val CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY = "num_items"
+  val DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS = 10000
+
+  /**
+   * False positive probability (FPP) key and default value.
+   */
+  val CLASSIC_BLOOM_FILTER_FPP_KEY = "fpp"
+  val DEFAULT_CLASSIC_BLOOM_FILTER_FPP = 0.03
+}
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala
new file mode 100644
index 000000000..c3db6fb1d
--- /dev/null
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/skipping/bloomfilter/BloomFilterSkippingStrategySuite.scala
@@ -0,0 +1,28 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.skipping.bloomfilter
+
+import org.opensearch.flint.spark.skipping.{FlintSparkSkippingStrategy, FlintSparkSkippingStrategySuite}
+import org.opensearch.flint.spark.skipping.bloomfilter.BloomFilterSkippingStrategy.{CLASSIC_BLOOM_FILTER_FPP_KEY, CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY, DEFAULT_CLASSIC_BLOOM_FILTER_FPP, DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS}
+import org.scalatest.matchers.should.Matchers
+
+import org.apache.spark.FlintSuite
+
+class BloomFilterSkippingStrategySuite
+    extends FlintSuite
+    with FlintSparkSkippingStrategySuite
+    with Matchers {
+
+  /** Subclass initializes strategy class to test */
+  override val strategy: FlintSparkSkippingStrategy =
+    BloomFilterSkippingStrategy(columnName = "name", columnType = "string")
+
+  test("parameters") {
+    strategy.parameters shouldBe Map(
+      CLASSIC_BLOOM_FILTER_NUM_ITEMS_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_NUM_ITEMS.toString,
+      CLASSIC_BLOOM_FILTER_FPP_KEY -> DEFAULT_CLASSIC_BLOOM_FILTER_FPP.toString)
+  }
+}
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
index 7a6ba99e1..c3b8caffd 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkSkippingIndexITSuite.scala
@@ -50,6 +50,7 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       .addPartitions("year", "month")
       .addValueSet("address")
       .addMinMax("age")
+      .addBloomFilter("name")
       .create()
 
     val index = flint.describeIndex(testIndex)
@@ -83,6 +84,15 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
        |     "parameters": {},
        |     "columnName": "age",
        |     "columnType": "int"
+       |   },
+       |   {
+       |     "kind": "BLOOM_FILTER",
+       |     "parameters": {
+       |       "num_items": "10000",
+       |       "fpp": "0.03"
+       |     },
+       |     "columnName": "name",
+       |     "columnType": "string"
        |   }],
        |   "source": "spark_catalog.default.test",
        |   "options": {
@@ -107,6 +117,9 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
        |     "MinMax_age_1" : {
        |       "type": "integer"
        |     },
+       |     "name" : {
+       |       "type": "binary"
+       |     },
        |     "file_path": {
        |       "type": "keyword"
        |     }
@@ -366,6 +379,20 @@ class FlintSparkSkippingIndexITSuite extends FlintSparkSuite {
       hasIndexFilter(col("MinMax_age_0") <= 30 && col("MinMax_age_1") >= 30))
   }
 
+  test("can build bloom filter skipping index and rewrite applicable query") {
+    flint
+      .skippingIndex()
+      .onTable(testTable)
+      .addBloomFilter("age")
+      .create()
+    flint.refreshIndex(testIndex)
+
+    // Assert index data
+    flint.queryIndex(testIndex).collect() should have size 2
+
+    // TODO: Assert query rewrite result
+  }
+
   test("should rewrite applicable query with table name without database specified") {
     flint
       .skippingIndex()