diff --git a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java index edc21f6bbd..bd7f654ce7 100644 --- a/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java +++ b/frontend/server/src/test/java/org/pytorch/serve/ModelServerTest.java @@ -1349,6 +1349,9 @@ public void testMetricManager() throws JsonParseException, InterruptedException // Define expected metrics // See ts/metrics/system_metrics.py, ts/configs/metrics.yaml Map> expectedMetrics = new HashMap<>(); + expectedMetrics.put("GPUMemoryUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); + expectedMetrics.put("GPUMemoryUsed", Map.of(UNIT, "Megabytes", LEVEL, HOST)); + expectedMetrics.put("GPUUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); expectedMetrics.put("CPUUtilization", Map.of(UNIT, "Percent", LEVEL, HOST)); expectedMetrics.put("MemoryUsed", Map.of(UNIT, "Megabytes", LEVEL, HOST)); expectedMetrics.put("MemoryAvailable", Map.of(UNIT, "Megabytes", LEVEL, HOST)); @@ -1369,7 +1372,8 @@ public void testMetricManager() throws JsonParseException, InterruptedException Assert.assertTrue(++count < 5); } - Assert.assertEquals(metrics.size(), expectedMetrics.size()); + // 7 system-level metrics + 3 gpu-specific metrics + Assert.assertEquals(metrics.size(), 7 + 3 * configManager.getNumberOfGpu()); for (Metric metric : metrics) { String metricName = metric.getMetricName();