Skip to content

Commit

Permalink
fix bug on serialization when passing task resource usage to coordinator
Browse files Browse the repository at this point in the history
Signed-off-by: Chenyang Ji <[email protected]>
  • Loading branch information
ansjcy authored and deshsidd committed Jul 18, 2024
1 parent 6e57c56 commit 24c930b
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.insights.rules.model;

import org.apache.lucene.util.ArrayUtil;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.tasks.resourcetracker.TaskResourceInfo;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

/**
* Valid attributes for a search query record
*
* @opensearch.internal
*/
public enum Attribute {
/**
* The search query type
*/
SEARCH_TYPE,
/**
* The search query source
*/
SOURCE,
/**
* Total shards queried
*/
TOTAL_SHARDS,
/**
* The indices involved
*/
INDICES,
/**
* The per phase level latency map for a search query
*/
PHASE_LATENCY_MAP,
/**
* The node id for this request
*/
NODE_ID,
/**
* Tasks level resource usages in this request
*/
TASK_RESOURCE_USAGES,
/**
* Custom search request labels
*/
LABELS;

/**
* Read an Attribute from a StreamInput
*
* @param in the StreamInput to read from
* @return Attribute
* @throws IOException IOException
*/
static Attribute readFromStream(final StreamInput in) throws IOException {
return Attribute.valueOf(in.readString().toUpperCase(Locale.ROOT));
}

/**
* Write Attribute to a StreamOutput
*
* @param out the StreamOutput to write
* @param attribute the Attribute to write
* @throws IOException IOException
*/
static void writeTo(final StreamOutput out, final Attribute attribute) throws IOException {
out.writeString(attribute.toString());
}

/**
* Write Attribute value to a StreamOutput
* @param out the StreamOutput to write
* @param attributeValue the Attribute value to write
*/
@SuppressWarnings("unchecked")
public static void writeValueTo(StreamOutput out, Object attributeValue) throws IOException {
if (attributeValue instanceof List) {
out.writeList((List<? extends Writeable>) attributeValue);
} else {
out.writeGenericValue(attributeValue);
}
}

/**
* Read attribute value from the input stream given the Attribute type
*
* @param in the {@link StreamInput} input to read
* @param attribute attribute type to differentiate between Source and others
* @return parse value
* @throws IOException IOException
*/
public static Object readAttributeValue(StreamInput in, Attribute attribute) throws IOException {
if (attribute == Attribute.TASK_RESOURCE_USAGES) {
return in.readList(TaskResourceInfo::readFromStream);
} else {
return in.readGenericValue();
}
}

/**
* Read attribute map from the input stream
*
* @param in the {@link StreamInput} to read
* @return parsed attribute map
* @throws IOException IOException
*/
public static Map<Attribute, Object> readAttributeMap(StreamInput in) throws IOException {
int size = readArraySize(in);
if (size == 0) {
return Collections.emptyMap();
}
Map<Attribute, Object> map = new HashMap<>(size);

for (int i = 0; i < size; i++) {
Attribute key = readFromStream(in);
Object value = readAttributeValue(in, key);
map.put(key, value);
}
return map;
}

private static int readArraySize(StreamInput in) throws IOException {
final int arraySize = in.readVInt();
if (arraySize > ArrayUtil.MAX_ARRAY_LENGTH) {
throw new IllegalStateException("array length must be <= to " + ArrayUtil.MAX_ARRAY_LENGTH + " but was: " + arraySize);
}
if (arraySize < 0) {
throw new NegativeArraySizeException("array size must be positive but was: " + arraySize);
}
return arraySize;
}

@Override
public String toString() {
return this.name().toLowerCase(Locale.ROOT);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.plugin.insights.rules.model;

import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

/**
* SearchQueryRecord represents a minimal atomic record stored in the Query Insight Framework,
* which contains extensive information related to a search query.
*
* @opensearch.internal
*/
public class SearchQueryRecord implements ToXContentObject, Writeable {
private final long timestamp;
private final Map<MetricType, Number> measurements;
private final Map<Attribute, Object> attributes;

/**
* Constructor of SearchQueryRecord
*
* @param in the StreamInput to read the SearchQueryRecord from
* @throws IOException IOException
* @throws ClassCastException ClassCastException
*/
public SearchQueryRecord(final StreamInput in) throws IOException, ClassCastException {
this.timestamp = in.readLong();
measurements = new HashMap<>();
in.readMap(MetricType::readFromStream, StreamInput::readGenericValue)
.forEach(((metricType, o) -> measurements.put(metricType, metricType.parseValue(o))));
this.attributes = Attribute.readAttributeMap(in);
}

/**
* Constructor of SearchQueryRecord
*
* @param timestamp The timestamp of the query.
* @param measurements A list of Measurement associated with this query
* @param attributes A list of Attributes associated with this query
*/
public SearchQueryRecord(final long timestamp, Map<MetricType, Number> measurements, final Map<Attribute, Object> attributes) {
if (measurements == null) {
throw new IllegalArgumentException("Measurements cannot be null");
}
this.measurements = measurements;
this.attributes = attributes;
this.timestamp = timestamp;
}

/**
* Returns the observation time of the metric.
*
* @return the observation time in milliseconds
*/
public long getTimestamp() {
return timestamp;
}

/**
* Returns the measurement associated with the specified name.
*
* @param name the name of the measurement
* @return the measurement object, or null if not found
*/
public Number getMeasurement(final MetricType name) {
return measurements.get(name);
}

/**
* Returns a map of all the measurements associated with the metric.
*
* @return a map of measurement names to measurement objects
*/
public Map<MetricType, Number> getMeasurements() {
return measurements;
}

/**
* Returns a map of the attributes associated with the metric.
*
* @return a map of attribute keys to attribute values
*/
public Map<Attribute, Object> getAttributes() {
return attributes;
}

/**
* Add an attribute to this record
*
* @param attribute attribute to add
* @param value the value associated with the attribute
*/
public void addAttribute(final Attribute attribute, final Object value) {
attributes.put(attribute, value);
}

@Override
public XContentBuilder toXContent(final XContentBuilder builder, final ToXContent.Params params) throws IOException {
builder.startObject();
builder.field("timestamp", timestamp);
for (Map.Entry<Attribute, Object> entry : attributes.entrySet()) {
builder.field(entry.getKey().toString(), entry.getValue());
}
for (Map.Entry<MetricType, Number> entry : measurements.entrySet()) {
builder.field(entry.getKey().toString(), entry.getValue());
}
return builder.endObject();
}

/**
* Write a SearchQueryRecord to a StreamOutput
*
* @param out the StreamOutput to write
* @throws IOException IOException
*/
@Override
public void writeTo(final StreamOutput out) throws IOException {
out.writeLong(timestamp);
out.writeMap(measurements, (stream, metricType) -> MetricType.writeTo(out, metricType), StreamOutput::writeGenericValue);
out.writeMap(
attributes,
(stream, attribute) -> Attribute.writeTo(out, attribute),
(stream, attributeValue) -> Attribute.writeValueTo(out, attributeValue)
);
}

/**
* Compare two SearchQueryRecord, based on the given MetricType
*
* @param a the first SearchQueryRecord to compare
* @param b the second SearchQueryRecord to compare
* @param metricType the MetricType to compare on
* @return 0 if the first SearchQueryRecord is numerically equal to the second SearchQueryRecord;
* -1 if the first SearchQueryRecord is numerically less than the second SearchQueryRecord;
* 1 if the first SearchQueryRecord is numerically greater than the second SearchQueryRecord.
*/
public static int compare(final SearchQueryRecord a, final SearchQueryRecord b, final MetricType metricType) {
return metricType.compare(a.getMeasurement(metricType), b.getMeasurement(metricType));
}

/**
* Check if a SearchQueryRecord is deep equal to another record
*
* @param o the other SearchQueryRecord record
* @return true if two records are deep equal, false otherwise.
*/
@Override
public boolean equals(final Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SearchQueryRecord)) {
return false;
}
final SearchQueryRecord other = (SearchQueryRecord) o;
return timestamp == other.getTimestamp()
&& measurements.equals(other.getMeasurements())
&& attributes.size() == other.getAttributes().size();
}

@Override
public int hashCode() {
return Objects.hash(timestamp, measurements, attributes);
}

@Override
public String toString() {
return Strings.toString(MediaTypeRegistry.JSON, this);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
/**
* Service that helps track resource usage of tasks running on a node.
*/
@PublicApi(since = "2.15.0")
@PublicApi(since = "2.16.0")
@SuppressForbidden(reason = "ThreadMXBean#getThreadAllocatedBytes")
public class TaskResourceTrackingService implements RunnableTaskExecutionListener {

Expand Down Expand Up @@ -359,7 +359,7 @@ public TaskResourceInfo getTaskResourceUsageFromThreadContext() {
/**
* Listener that gets invoked when a task execution completes.
*/
@PublicApi(since = "2.15.0")
@PublicApi(since = "2.16.0")
public interface TaskCompletionListener {
void onTaskCompleted(Task task);
}
Expand Down

0 comments on commit 24c930b

Please sign in to comment.