diff --git a/include/triton/core/tritonserver.h b/include/triton/core/tritonserver.h
index edaf554b2..baa757128 100644
--- a/include/triton/core/tritonserver.h
+++ b/include/triton/core/tritonserver.h
@@ -2720,6 +2720,20 @@ TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricSet(
 TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricObserve(
     struct TRITONSERVER_Metric* metric, double value);
 
+/// Collect the current value of a metric.
+/// Supports metrics of kind TRITONSERVER_METRIC_KIND_COUNTER,
+/// TRITONSERVER_METRIC_KIND_GAUGE and TRITONSERVER_METRIC_KIND_HISTOGRAM.
+/// Returns TRITONSERVER_ERROR_UNSUPPORTED for any other
+/// TRITONSERVER_MetricKind.
+///
+/// \param metric The metric object to collect.
+/// \param value Returns the current value of the metric object. Must point
+/// to a caller-owned prometheus::ClientMetric, which is overwritten with the
+/// collected value.
+/// \return a TRITONSERVER_Error indicating success or failure.
+TRITONSERVER_DECLSPEC struct TRITONSERVER_Error* TRITONSERVER_MetricCollect(
+    struct TRITONSERVER_Metric* metric, void* value);
+
 /// Get the TRITONSERVER_MetricKind of metric and its corresponding family.
 ///
 /// \param metric The metric object to query.
diff --git a/src/metric_family.cc b/src/metric_family.cc
index a685514ce..17380ddf0 100644
--- a/src/metric_family.cc
+++ b/src/metric_family.cc
@@ -76,6 +76,10 @@ MetricFamily::Add(
   void* prom_metric = nullptr;
   switch (kind_) {
     case TRITONSERVER_METRIC_KIND_COUNTER: {
+      if (buckets != nullptr) {
+        throw std::invalid_argument(
+            "Unexpected buckets found in counter Metric constructor.");
+      }
       auto counter_family_ptr =
           reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
       auto counter_ptr = &counter_family_ptr->Add(label_map);
@@ -83,6 +87,10 @@ MetricFamily::Add(
       break;
     }
     case TRITONSERVER_METRIC_KIND_GAUGE: {
+      if (buckets != nullptr) {
+        throw std::invalid_argument(
+            "Unexpected buckets found in gauge Metric constructor.");
+      }
       auto gauge_family_ptr =
           reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
       auto gauge_ptr = &gauge_family_ptr->Add(label_map);
@@ -92,7 +100,7 @@ MetricFamily::Add(
     case TRITONSERVER_METRIC_KIND_HISTOGRAM: {
       if (buckets == nullptr) {
         throw std::invalid_argument(
-            "Histogram must be constructed with bucket boundaries.");
+            "Missing required buckets in histogram Metric constructor.");
       }
       auto histogram_family_ptr =
           reinterpret_cast<prometheus::Family<prometheus::Histogram>*>(family_);
@@ -394,6 +402,49 @@ Metric::Observe(double value)
   return nullptr;  // Success
 }
 
+// Copy the current state of the underlying prometheus metric into 'value'.
+// Returns an error if 'value' is null, if the metric has been invalidated,
+// or if the metric kind is not counter, gauge or histogram.
+TRITONSERVER_Error*
+Metric::Collect(prometheus::ClientMetric* value)
+{
+  if (value == nullptr) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INVALID_ARG,
+        "Could not collect metric value. Output value is nullptr.");
+  }
+
+  if (metric_ == nullptr) {
+    return TRITONSERVER_ErrorNew(
+        TRITONSERVER_ERROR_INTERNAL,
+        "Could not collect metric value. Metric has been invalidated.");
+  }
+
+  switch (kind_) {
+    case TRITONSERVER_METRIC_KIND_COUNTER: {
+      auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
+      *value = counter_ptr->Collect();
+      break;
+    }
+    case TRITONSERVER_METRIC_KIND_GAUGE: {
+      auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
+      *value = gauge_ptr->Collect();
+      break;
+    }
+    case TRITONSERVER_METRIC_KIND_HISTOGRAM: {
+      auto histogram_ptr = reinterpret_cast<prometheus::Histogram*>(metric_);
+      *value = histogram_ptr->Collect();
+      break;
+    }
+    default:
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_UNSUPPORTED,
+          "Unsupported TRITONSERVER_MetricKind");
+  }
+
+  return nullptr;  // Success
+}
+
 }}  // namespace triton::core
 
 #endif  // TRITON_ENABLE_METRICS
diff --git a/src/metric_family.h b/src/metric_family.h
index d3e305c76..3565d8fec 100644
--- a/src/metric_family.h
+++ b/src/metric_family.h
@@ -99,6 +99,7 @@ class Metric {
   TRITONSERVER_Error* Increment(double value);
   TRITONSERVER_Error* Set(double value);
   TRITONSERVER_Error* Observe(double value);
+  TRITONSERVER_Error* Collect(prometheus::ClientMetric* value);
 
   // If a MetricFamily is deleted before its dependent Metric, we want to
   // invalidate the references so we don't access invalid memory.
diff --git a/src/test/metrics_api_test.cc b/src/test/metrics_api_test.cc
index 3356493a3..35f6b27b5 100644
--- a/src/test/metrics_api_test.cc
+++ b/src/test/metrics_api_test.cc
@@ -232,6 +232,35 @@ MetricAPIHelper(TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind kind)
   TRITONSERVER_ErrorDelete(err);
 }
 
+// Observe a fixed set of samples on a histogram 'metric', then collect it
+// and verify the sample count, sample sum and cumulative bucket counts.
+// The expected counts assume buckets {0.1, 1.0, 2.5, 5.0, 10.0}.
+void
+HistogramAPIHelper(TRITONSERVER_Metric* metric)
+{
+  // Observe
+  std::vector<double> data{0.05, 1.5, 6.0};
+  std::vector<uint64_t> cumulative_counts = {1, 1, 2, 2, 3, 3};
+  double sum = 0.0;
+  for (auto datum : data) {
+    FAIL_TEST_IF_ERR(
+        TRITONSERVER_MetricObserve(metric, datum), "observe metric value");
+    sum += datum;
+  }
+
+  // Collect
+  prometheus::ClientMetric value;
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_MetricCollect(metric, &value),
+      "query metric value after observe");
+  const auto& hist = value.histogram;
+  ASSERT_EQ(hist.sample_count, data.size());
+  ASSERT_DOUBLE_EQ(hist.sample_sum, sum);
+  ASSERT_EQ(hist.bucket.size(), cumulative_counts.size());
+  for (size_t i = 0; i < hist.bucket.size(); ++i) {
+    ASSERT_EQ(hist.bucket[i].cumulative_count, cumulative_counts[i]);
+  }
+}
 
 // Test Fixture
 class MetricsApiTest : public ::testing::Test {
@@ -364,6 +393,52 @@ TEST_F(MetricsApiTest, TestGaugeEndToEnd)
   ASSERT_EQ(NumMetricMatches(server_, description), 0);
 }
 
+// Test end-to-end flow of Generic Metrics API for Histogram metric
+TEST_F(MetricsApiTest, TestHistogramEndToEnd)
+{
+  // Create metric family
+  TRITONSERVER_MetricFamily* family;
+  TRITONSERVER_MetricKind kind = TRITONSERVER_METRIC_KIND_HISTOGRAM;
+  const char* name = "custom_histogram_example";
+  const char* description =
+      "this is an example histogram metric added via API.";
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_MetricFamilyNew(&family, kind, name, description),
+      "Creating new metric family");
+
+  // Create metric
+  TRITONSERVER_Metric* metric;
+  std::vector<const TRITONSERVER_Parameter*> labels;
+  labels.emplace_back(TRITONSERVER_ParameterNew(
+      "example1", TRITONSERVER_PARAMETER_STRING, "histogram_label1"));
+  labels.emplace_back(TRITONSERVER_ParameterNew(
+      "example2", TRITONSERVER_PARAMETER_STRING, "histogram_label2"));
+  std::vector<double> buckets = {0.1, 1.0, 2.5, 5.0, 10.0};
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_MetricNew(
+          &metric, family, labels.data(), labels.size(),
+          reinterpret_cast<void*>(&buckets)),
+      "Creating new metric");
+  for (const auto label : labels) {
+    TRITONSERVER_ParameterDelete(const_cast<TRITONSERVER_Parameter*>(label));
+  }
+
+  // Run through metric APIs and assert correctness
+  HistogramAPIHelper(metric);
+
+  // Assert custom metric is reported and found in output
+  ASSERT_EQ(NumMetricMatches(server_, description), 1);
+
+  // Cleanup
+  FAIL_TEST_IF_ERR(TRITONSERVER_MetricDelete(metric), "delete metric");
+  FAIL_TEST_IF_ERR(
+      TRITONSERVER_MetricFamilyDelete(family), "delete metric family");
+
+  // Assert custom metric/family is unregistered and no longer in output
+  ASSERT_EQ(NumMetricMatches(server_, description), 0);
+}
+
+
 // Test that a duplicate metric family can't be added
 // with a conflicting type/kind
 TEST_F(MetricsApiTest, TestDupeMetricFamilyDiffKind)
diff --git a/src/tritonserver.cc b/src/tritonserver.cc
index a9d0d2f6d..ac36c2445 100644
--- a/src/tritonserver.cc
+++ b/src/tritonserver.cc
@@ -3463,6 +3463,18 @@ TRITONSERVER_MetricObserve(TRITONSERVER_Metric* metric, double value)
 #endif  // TRITON_ENABLE_METRICS
 }
 
+TRITONSERVER_Error*
+TRITONSERVER_MetricCollect(TRITONSERVER_Metric* metric, void* value)
+{
+#ifdef TRITON_ENABLE_METRICS
+  return reinterpret_cast<tc::Metric*>(metric)->Collect(
+      reinterpret_cast<prometheus::ClientMetric*>(value));
+#else
+  return TRITONSERVER_ErrorNew(
+      TRITONSERVER_ERROR_UNSUPPORTED, "metrics not supported");
+#endif  // TRITON_ENABLE_METRICS
+}
+
 TRITONSERVER_Error*
 TRITONSERVER_GetMetricKind(
     TRITONSERVER_Metric* metric, TRITONSERVER_MetricKind* kind)