Skip to content

Commit

Permalink
[ML] Add stat for non cache hit inference time (#90464) (#90510)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle authored Sep 29, 2022
1 parent ff7a809 commit 1ca235d
Show file tree
Hide file tree
Showing 11 changed files with 153 additions and 26 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/90464.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 90464
summary: Add measure of non cache hit inference time
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,14 @@ The deployment stats for each node that currently has the model allocated.
The average time for each inference call to complete on this node.
The average is calculated over the lifetime of the deployment.
`average_inference_time_ms_excluding_cache_hits`:::
(double)
The average time to perform inference on the trained model, excluding
occasions where the response comes from the cache. Cached inference
calls return very quickly because the model is not evaluated; by
excluding cache hits, this value is an accurate measure of the average
time taken to evaluate the model.
`average_inference_time_ms_last_minute`:::
(double)
The average time for each inference call to complete on this node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public static class NodeStats implements ToXContentObject, Writeable {
private final DiscoveryNode node;
private final Long inferenceCount;
private final Double avgInferenceTime;
private final Double avgInferenceTimeExcludingCacheHit;
private final Instant lastAccess;
private final Integer pendingCount;
private final int errorCount;
Expand All @@ -51,6 +52,7 @@ public static AssignmentStats.NodeStats forStartedState(
DiscoveryNode node,
long inferenceCount,
Double avgInferenceTime,
Double avgInferenceTimeExcludingCacheHit,
int pendingCount,
int errorCount,
long cacheHitCount,
Expand All @@ -69,6 +71,7 @@ public static AssignmentStats.NodeStats forStartedState(
node,
inferenceCount,
avgInferenceTime,
avgInferenceTimeExcludingCacheHit,
lastAccess,
pendingCount,
errorCount,
Expand All @@ -93,6 +96,7 @@ public static AssignmentStats.NodeStats forNotStartedState(DiscoveryNode node, R
null,
null,
null,
null,
0,
null,
0,
Expand All @@ -112,6 +116,7 @@ public NodeStats(
DiscoveryNode node,
Long inferenceCount,
Double avgInferenceTime,
Double avgInferenceTimeExcludingCacheHit,
@Nullable Instant lastAccess,
Integer pendingCount,
int errorCount,
Expand All @@ -130,6 +135,7 @@ public NodeStats(
this.node = node;
this.inferenceCount = inferenceCount;
this.avgInferenceTime = avgInferenceTime;
this.avgInferenceTimeExcludingCacheHit = avgInferenceTimeExcludingCacheHit;
this.lastAccess = lastAccess;
this.pendingCount = pendingCount;
this.errorCount = errorCount;
Expand Down Expand Up @@ -186,6 +192,12 @@ public NodeStats(StreamInput in) throws IOException {
this.cacheHitCount = null;
this.cacheHitCountLastPeriod = null;
}
if (in.getVersion().onOrAfter(Version.V_8_5_0)) {
this.avgInferenceTimeExcludingCacheHit = in.readOptionalDouble();
} else {
this.avgInferenceTimeExcludingCacheHit = null;
}

}

public DiscoveryNode getNode() {
Expand All @@ -204,6 +216,10 @@ public Optional<Double> getAvgInferenceTime() {
return Optional.ofNullable(avgInferenceTime);
}

/**
 * The average inference time on this node computed only over requests that
 * were not served from the inference cache.
 *
 * @return the average time in milliseconds, or {@link Optional#empty()} when
 *         the value has not been recorded (e.g. stats from an older node)
 */
public Optional<Double> getAvgInferenceTimeExcludingCacheHit() {
    if (avgInferenceTimeExcludingCacheHit == null) {
        return Optional.empty();
    }
    return Optional.of(avgInferenceTimeExcludingCacheHit);
}

/**
 * The time of the most recent access on this node.
 *
 * @return the last access time; may be {@code null} (the field is declared
 *         {@code @Nullable}) when no inference call has been made yet
 */
public Instant getLastAccess() {
    return lastAccess;
}
Expand Down Expand Up @@ -269,8 +285,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field("inference_count", inferenceCount);
}
// avoid reporting the average time as 0 if count < 1
if (avgInferenceTime != null && (inferenceCount != null && inferenceCount > 0)) {
builder.field("average_inference_time_ms", avgInferenceTime);
if (inferenceCount != null && inferenceCount > 0) {
if (avgInferenceTime != null) {
builder.field("average_inference_time_ms", avgInferenceTime);
}
if (avgInferenceTimeExcludingCacheHit != null) {
builder.field("average_inference_time_ms_excluding_cache_hits", avgInferenceTimeExcludingCacheHit);
}
}
if (cacheHitCount != null) {
builder.field("inference_cache_hit_count", cacheHitCount);
Expand Down Expand Up @@ -337,6 +358,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalVLong(cacheHitCount);
out.writeOptionalVLong(cacheHitCountLastPeriod);
}
if (out.getVersion().onOrAfter(Version.V_8_5_0)) {
out.writeOptionalDouble(avgInferenceTimeExcludingCacheHit);
}
}

@Override
Expand All @@ -346,6 +370,7 @@ public boolean equals(Object o) {
AssignmentStats.NodeStats that = (AssignmentStats.NodeStats) o;
return Objects.equals(inferenceCount, that.inferenceCount)
&& Objects.equals(that.avgInferenceTime, avgInferenceTime)
&& Objects.equals(that.avgInferenceTimeExcludingCacheHit, avgInferenceTimeExcludingCacheHit)
&& Objects.equals(node, that.node)
&& Objects.equals(lastAccess, that.lastAccess)
&& Objects.equals(pendingCount, that.pendingCount)
Expand All @@ -369,6 +394,7 @@ public int hashCode() {
node,
inferenceCount,
avgInferenceTime,
avgInferenceTimeExcludingCacheHit,
lastAccess,
pendingCount,
errorCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ public class GetTrainedModelsStatsActionResponseTests extends AbstractBWCWireSer

@Override
protected Response createTestInstance() {
int listSize = randomInt(10);
// int listSize = randomInt(10);
int listSize = 1;
List<Response.TrainedModelStats> trainedModelStats = Stream.generate(() -> randomAlphaOfLength(10))
.limit(listSize)
.map(
Expand Down Expand Up @@ -123,6 +124,7 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
null,
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
0,
Expand Down Expand Up @@ -178,6 +180,7 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
null,
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
Expand Down Expand Up @@ -233,6 +236,7 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
null,
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
Expand All @@ -258,6 +262,62 @@ protected Response mutateInstanceForVersion(Response instance, Version version)
RESULTS_FIELD
)
);
} else if (version.before(Version.V_8_5_0)) {
return new Response(
new QueryPage<>(
instance.getResources()
.results()
.stream()
.map(
stats -> new Response.TrainedModelStats(
stats.getModelId(),
stats.getModelSizeStats(),
stats.getIngestStats(),
stats.getPipelineCount(),
stats.getInferenceStats(),
stats.getDeploymentStats() == null
? null
: new AssignmentStats(
stats.getDeploymentStats().getModelId(),
stats.getDeploymentStats().getThreadsPerAllocation(),
stats.getDeploymentStats().getNumberOfAllocations(),
stats.getDeploymentStats().getQueueCapacity(),
stats.getDeploymentStats().getCacheSize(),
stats.getDeploymentStats().getStartTime(),
stats.getDeploymentStats()
.getNodeStats()
.stream()
.map(
nodeStats -> new AssignmentStats.NodeStats(
nodeStats.getNode(),
nodeStats.getInferenceCount().orElse(null),
nodeStats.getAvgInferenceTime().orElse(null),
null,
nodeStats.getLastAccess(),
nodeStats.getPendingCount(),
nodeStats.getErrorCount(),
nodeStats.getCacheHitCount().orElse(null),
nodeStats.getRejectedExecutionCount(),
nodeStats.getTimeoutCount(),
nodeStats.getRoutingState(),
nodeStats.getStartTime(),
nodeStats.getThreadsPerAllocation(),
nodeStats.getNumberOfAllocations(),
nodeStats.getPeakThroughput(),
nodeStats.getThroughputLastPeriod(),
nodeStats.getAvgInferenceTimeLastPeriod(),
nodeStats.getCacheHitCountLastPeriod().orElse(null)
)
)
.toList()
)
)
)
.toList(),
instance.getResources().count(),
RESULTS_FIELD
)
);
}
return instance;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,22 @@ public static AssignmentStats.NodeStats randomNodeStats(DiscoveryNode node) {
var lastAccess = Instant.now();
var inferenceCount = randomNonNegativeLong();
Double avgInferenceTime = randomDoubleBetween(0.0, 100.0, true);
Double avgInferenceTimeExcludingCacheHit = randomDoubleBetween(0.0, 100.0, true);
Double avgInferenceTimeLastPeriod = randomDoubleBetween(0.0, 100.0, true);

var noInferenceCallsOnNodeYet = randomBoolean();
if (noInferenceCallsOnNodeYet) {
lastAccess = null;
inferenceCount = 0;
avgInferenceTime = null;
avgInferenceTimeExcludingCacheHit = null;
avgInferenceTimeLastPeriod = null;
}
return AssignmentStats.NodeStats.forStartedState(
node,
inferenceCount,
avgInferenceTime,
avgInferenceTimeExcludingCacheHit,
randomIntBetween(0, 100),
randomIntBetween(0, 100),
randomLongBetween(0, 100),
Expand Down Expand Up @@ -102,6 +105,7 @@ public void testGetOverallInferenceStats() {
new DiscoveryNode("node_started_1", buildNewFakeTransportAddress(), Version.CURRENT),
10L,
randomDoubleBetween(0.0, 100.0, true),
randomDoubleBetween(0.0, 100.0, true),
randomIntBetween(1, 10),
5,
4L,
Expand All @@ -120,6 +124,7 @@ public void testGetOverallInferenceStats() {
new DiscoveryNode("node_started_2", buildNewFakeTransportAddress(), Version.CURRENT),
12L,
randomDoubleBetween(0.0, 100.0, true),
randomDoubleBetween(0.0, 100.0, true),
randomIntBetween(1, 10),
15,
3L,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,9 @@ protected void taskOperation(
nodeStats.add(
AssignmentStats.NodeStats.forStartedState(
clusterService.localNode(),
presentValue.timingStats().getCount(),
presentValue.timingStats().getAverage(),
presentValue.inferenceCount(),
presentValue.averageInferenceTime(),
presentValue.averageInferenceTimeNoCacheHits(),
presentValue.pendingCount(),
presentValue.errorCount(),
presentValue.cacheHitCount(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ public Optional<ModelStats> getStats(TrainedModelDeploymentTask task) {
var recentStats = stats.recentStats();
return new ModelStats(
processContext.startTime,
stats.timingStats(),
stats.timingStats().getCount(),
stats.timingStats().getAverage(),
stats.timingStatsExcludingCacheHits().getAverage(),
stats.lastUsed(),
processContext.executorService.queueSize() + stats.numberOfPendingResults(),
stats.errorCount(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
package org.elasticsearch.xpack.ml.inference.deployment;

import java.time.Instant;
import java.util.LongSummaryStatistics;

public record ModelStats(
Instant startTime,
LongSummaryStatistics timingStats,
long inferenceCount,
Double averageInferenceTime,
Double averageInferenceTimeNoCacheHits,
Instant lastUsed,
int pendingCount,
int errorCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public record RecentStats(long requestsProcessed, Double avgInferenceTime, long

public record ResultStats(
LongSummaryStatistics timingStats,
LongSummaryStatistics timingStatsExcludingCacheHits,
int errorCount,
long cacheHitCount,
int numberOfPendingResults,
Expand All @@ -51,6 +52,7 @@ public record ResultStats(
private final Consumer<ThreadSettings> threadSettingsConsumer;
private volatile boolean isStopping;
private final LongSummaryStatistics timingStats;
private final LongSummaryStatistics timingStatsExcludingCacheHits;
private int errorCount;
private long cacheHitCount;
private long peakThroughput;
Expand All @@ -71,6 +73,7 @@ public PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> thre
PyTorchResultProcessor(String deploymentId, Consumer<ThreadSettings> threadSettingsConsumer, LongSupplier currentTimeSupplier) {
this.deploymentId = Objects.requireNonNull(deploymentId);
this.timingStats = new LongSummaryStatistics();
this.timingStatsExcludingCacheHits = new LongSummaryStatistics();
this.lastPeriodSummaryStats = new LongSummaryStatistics();
this.threadSettingsConsumer = Objects.requireNonNull(threadSettingsConsumer);
this.currentTimeMsSupplier = currentTimeSupplier;
Expand Down Expand Up @@ -157,7 +160,7 @@ void processInferenceResult(PyTorchResult result) {
}

logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId()));
processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit()));
updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId()));
Expand Down Expand Up @@ -235,7 +238,8 @@ public synchronized ResultStats getResultStats() {
}

return new ResultStats(
new LongSummaryStatistics(timingStats.getCount(), timingStats.getMin(), timingStats.getMax(), timingStats.getSum()),
cloneSummaryStats(timingStats),
cloneSummaryStats(timingStatsExcludingCacheHits),
errorCount,
cacheHitCount,
pendingResults.size(),
Expand All @@ -245,7 +249,11 @@ public synchronized ResultStats getResultStats() {
);
}

private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) {
/**
 * Takes a point-in-time copy of the given statistics so the snapshot returned
 * to callers is not affected by subsequent result processing.
 *
 * @param stats the live, mutable statistics to copy
 * @return an independent copy with the same count, min, max and sum
 */
private LongSummaryStatistics cloneSummaryStats(LongSummaryStatistics stats) {
    // combine() on a fresh instance copies count/min/max/sum, including the
    // empty case (count == 0, min == MAX_VALUE, max == MIN_VALUE).
    var snapshot = new LongSummaryStatistics();
    snapshot.combine(stats);
    return snapshot;
}

private synchronized void updateStats(long timeMs, boolean isCacheHit) {
timingStats.accept(timeMs);

lastResultTimeMs = currentTimeMsSupplier.getAsLong();
Expand Down Expand Up @@ -278,6 +286,9 @@ private synchronized void processResult(PyTorchInferenceResult result, long time
if (isCacheHit) {
cacheHitCount++;
lastPeriodCacheHitCount++;
} else {
// don't include cache hits when recording inference time
timingStatsExcludingCacheHits.accept(timeMs);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ public void testUsage() throws Exception {
new DiscoveryNode("foo", new TransportAddress(TransportAddress.META_ADDRESS, 2), Version.CURRENT),
5,
42.0,
42.0,
0,
1,
3L,
Expand All @@ -399,6 +400,7 @@ public void testUsage() throws Exception {
new DiscoveryNode("bar", new TransportAddress(TransportAddress.META_ADDRESS, 3), Version.CURRENT),
4,
50.0,
50.0,
0,
1,
1L,
Expand Down
Loading

0 comments on commit 1ca235d

Please sign in to comment.