diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java index dcdb085fdc6fa..80c80fb6b6937 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/Attribute.java @@ -8,11 +8,18 @@ package org.opensearch.plugin.insights.rules.model; +import org.apache.lucene.util.ArrayUtil; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo; import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; import java.util.Locale; +import java.util.Map; /** * Valid attributes for a search query record @@ -75,6 +82,69 @@ static void writeTo(final StreamOutput out, final Attribute attribute) throws IO out.writeString(attribute.toString()); } + /** + * Write Attribute value to a StreamOutput + * @param out the StreamOutput to write + * @param attributeValue the Attribute value to write + */ + @SuppressWarnings("unchecked") + public static void writeValueTo(StreamOutput out, Object attributeValue) throws IOException { + if (attributeValue instanceof List) { + out.writeList((List) attributeValue); + } else { + out.writeGenericValue(attributeValue); + } + } + + /** + * Read attribute value from the input stream given the Attribute type + * + * @param in the {@link StreamInput} input to read + * @param attribute attribute type to differentiate between Source and others + * @return parse value + * @throws IOException IOException + */ + public static Object readAttributeValue(StreamInput in, Attribute attribute) throws IOException { + if (attribute == Attribute.TASK_RESOURCE_USAGES) { + return in.readList(TaskResourceInfo::readFromStream); + } else { + return in.readGenericValue(); + } + } + + /** + * Read attribute map from the input stream + * + * @param in the {@link StreamInput} to read + * @return parsed attribute map + * @throws IOException IOException + */ + public static Map readAttributeMap(StreamInput in) throws IOException { + int size = readArraySize(in); + if (size == 0) { + return Collections.emptyMap(); + } + Map map = new HashMap<>(size); + + for (int i = 0; i < size; i++) { + Attribute key = readFromStream(in); + Object value = readAttributeValue(in, key); + map.put(key, value); + } + return map; + } + + private static int readArraySize(StreamInput in) throws IOException { + final int arraySize = in.readVInt(); + if (arraySize > ArrayUtil.MAX_ARRAY_LENGTH) { + throw new IllegalStateException("array length must be <= to " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: " + arraySize); + } + if (arraySize < 0) { + throw new NegativeArraySizeException("array size must be positive but was: " + arraySize); + } + return arraySize; + } + @Override public String toString() { return this.name().toLowerCase(Locale.ROOT); diff --git a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/SearchQueryRecord.java b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/SearchQueryRecord.java index fec00a680ae58..a6e6b4a9051f0 100644 --- a/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/SearchQueryRecord.java +++ b/plugins/query-insights/src/main/java/org/opensearch/plugin/insights/rules/model/SearchQueryRecord.java @@ -45,7 +45,7 @@ public SearchQueryRecord(final StreamInput in) throws IOException, ClassCastExce measurements = new HashMap<>(); in.readMap(MetricType::readFromStream, StreamInput::readGenericValue) .forEach(((metricType, o) -> measurements.put(metricType, metricType.parseValue(o)))); - this.attributes = in.readMap(Attribute::readFromStream, StreamInput::readGenericValue); + this.attributes = Attribute.readAttributeMap(in); } /** @@ -134,7 +134,11 @@ public XContentBuilder toXContent(final XContentBuilder builder, final ToXConten public void writeTo(final StreamOutput out) throws IOException { out.writeLong(timestamp); out.writeMap(measurements, (stream, metricType) -> MetricType.writeTo(out, metricType), StreamOutput::writeGenericValue); - out.writeMap(attributes, (stream, attribute) -> Attribute.writeTo(out, attribute), StreamOutput::writeGenericValue); + out.writeMap( + attributes, + (stream, attribute) -> Attribute.writeTo(out, attribute), + (stream, attributeValue) -> Attribute.writeValueTo(out, attributeValue) + ); } /** diff --git a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java index ea62093d0c893..80c9e2227e9fe 100644 --- a/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java +++ b/server/src/main/java/org/opensearch/tasks/TaskResourceTrackingService.java @@ -52,7 +52,7 @@ /** * Service that helps track resource usage of tasks running on a node. */ -@PublicApi(since = "2.15.0") +@PublicApi(since = "2.16.0") @SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes") public class TaskResourceTrackingService implements RunnableTaskExecutionListener { @@ -359,7 +359,7 @@ public TaskResourceInfo getTaskResourceUsageFromThreadContext() { /** * Listener that gets invoked when a task execution completes. */ - @PublicApi(since = "2.15.0") + @PublicApi(since = "2.16.0") public interface TaskCompletionListener { void onTaskCompleted(Task task); }