From 5c5a9f161855eafd9313db9c00ab24ca02abb6f3 Mon Sep 17 00:00:00 2001 From: Darnell Andries Date: Thu, 29 Jun 2023 18:02:54 -0700 Subject: [PATCH] Refactor measurement reporter --- src/aggregator/group.rs | 6 +- src/aggregator/mod.rs | 2 +- src/aggregator/processing.rs | 26 ++--- src/aggregator/recovered.rs | 14 +-- src/aggregator/report.rs | 221 ++++++++++++++++------------------- src/models/mod.rs | 2 +- src/models/pending_msg.rs | 8 +- src/models/recovered_msg.rs | 10 +- 8 files changed, 134 insertions(+), 155 deletions(-) diff --git a/src/aggregator/group.rs b/src/aggregator/group.rs index dfcb4b3..2d8d270 100644 --- a/src/aggregator/group.rs +++ b/src/aggregator/group.rs @@ -87,7 +87,7 @@ impl GroupedMessages { for tag in tags { let msgs = - PendingMessage::list(conn.clone(), epoch, tag.clone(), profiler.clone()).await?; + PendingMessage::list(conn.clone(), epoch, tag.clone(), profiler.as_ref()).await?; pending_msgs.insert(tag, msgs); } Ok(pending_msgs) @@ -132,7 +132,7 @@ impl GroupedMessages { for new_msgs in new_pending_msgs.chunks(INSERT_BATCH_SIZE) { let new_msgs = new_msgs.to_vec(); new_msgs - .insert_batch(store_conns.get(), profiler.clone()) + .insert_batch(store_conns.get(), profiler.as_ref()) .await?; } } @@ -337,7 +337,7 @@ mod tests { let db_pool = Arc::new(DBPool::new(true)); let conn = Arc::new(Mutex::new(db_pool.get().await.unwrap())); new_rec_msgs - .insert_batch(conn.clone(), profiler.clone()) + .insert_batch(conn.clone(), profiler.as_ref()) .await .unwrap(); drop(conn); diff --git a/src/aggregator/mod.rs b/src/aggregator/mod.rs index 0f25101..d20d721 100644 --- a/src/aggregator/mod.rs +++ b/src/aggregator/mod.rs @@ -173,7 +173,7 @@ pub async fn start_aggregation( db_conn.clone(), &epoch_config, out_stream.as_ref().map(|v| v.as_ref()), - profiler.clone(), + profiler.as_ref(), ) .await?; if let Some(out_stream) = out_stream.as_ref() { diff --git a/src/aggregator/processing.rs b/src/aggregator/processing.rs index 318524f..be17915 100644 --- a/src/aggregator/processing.rs +++ b/src/aggregator/processing.rs @@ -1,6 +1,6 @@ use super::group::GroupedMessages; use super::recovered::RecoveredMessages; -use super::report::report_measurements; +use super::report::MeasurementReporter; use super::AggregatorError; use crate::epoch::{is_epoch_expired, EpochConfig}; use crate::models::{DBConnection, DBPool, DBStorageConnections, PendingMessage, RecoveredMessage}; @@ -16,7 +16,7 @@ pub async fn process_expired_epochs( conn: Arc>, epoch_config: &EpochConfig, out_stream: Option<&DynRecordStream>, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), AggregatorError> { let epochs = RecoveredMessage::list_distinct_epochs(conn.clone()).await?; for epoch in epochs { @@ -29,15 +29,9 @@ pub async fn process_expired_epochs( .fetch_all_recovered_with_nonzero_count(conn.clone(), epoch as u8, profiler.clone()) .await?; - report_measurements( - &mut rec_msgs, - epoch_config, - epoch as u8, - true, - out_stream, - profiler.clone(), - ) - .await?; + MeasurementReporter::new(epoch_config, out_stream, profiler, epoch as u8, true) + .report(&mut rec_msgs) + .await?; RecoveredMessage::delete_epoch(conn.clone(), epoch, profiler.clone()).await?; PendingMessage::delete_epoch(conn.clone(), epoch, profiler.clone()).await?; } @@ -194,7 +188,7 @@ pub fn start_subtask( info!("Task {}: Deleting old pending messages", id); for (epoch, msg_tag) in pending_tags_to_remove { - PendingMessage::delete_tag(store_conns.get(), epoch as i16, msg_tag, profiler.clone()) + PendingMessage::delete_tag(store_conns.get(), epoch as i16, msg_tag, profiler.as_ref()) .await .unwrap(); } @@ -205,14 +199,14 @@ pub fn start_subtask( let rec_epochs: Vec = rec_msgs.map.keys().cloned().collect(); let mut measurements_count = 0; for epoch in rec_epochs { - measurements_count += report_measurements( - &mut rec_msgs, + measurements_count += MeasurementReporter::new( epoch_config.as_ref(), + out_stream.as_ref().map(|v| v.as_ref()), + profiler.as_ref(), epoch, false, - out_stream.as_ref().map(|v| v.as_ref()), - profiler.clone(), ) + .report(&mut rec_msgs) .await .unwrap(); } diff --git a/src/aggregator/recovered.rs b/src/aggregator/recovered.rs index a121b6e..54de1a2 100644 --- a/src/aggregator/recovered.rs +++ b/src/aggregator/recovered.rs @@ -64,7 +64,7 @@ impl RecoveredMessages { conn.clone(), epoch as i16, msg_tags.to_vec(), - profiler.clone(), + profiler.as_ref(), ) .await?; for rec_msg in recovered_msgs { @@ -78,7 +78,7 @@ impl RecoveredMessages { &mut self, conn: Arc>, epoch: u8, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), AggregatorError> { let recovered_msgs = RecoveredMessage::list_with_nonzero_count(conn, epoch as i16, profiler).await?; @@ -103,7 +103,7 @@ impl RecoveredMessages { store_conns.get(), rec_msg.id, rec_msg.count, - profiler.clone(), + profiler.as_ref(), ) .await?; } @@ -112,7 +112,7 @@ impl RecoveredMessages { for new_msgs in new_msgs.chunks(INSERT_BATCH_SIZE) { let new_msgs = new_msgs.to_vec(); new_msgs - .insert_batch(store_conns.get(), profiler.clone()) + .insert_batch(store_conns.get(), profiler.as_ref()) .await?; } Ok(()) @@ -260,7 +260,7 @@ mod tests { new_rec_msgs .clone() - .insert_batch(conn.clone(), profiler.clone()) + .insert_batch(conn.clone(), profiler.as_ref()) .await .unwrap(); @@ -268,7 +268,7 @@ mod tests { for epoch in 3..=4 { let mut rec_msg = - RecoveredMessage::list(conn.clone(), epoch, vec![vec![60; 20]], profiler.clone()) + RecoveredMessage::list(conn.clone(), epoch, vec![vec![60; 20]], profiler.as_ref()) .await .unwrap()[0] .clone(); @@ -287,7 +287,7 @@ mod tests { for epoch in 3..=4 { let rec_msg = - RecoveredMessage::list(conn.clone(), epoch, vec![vec![60; 20]], profiler.clone()) + RecoveredMessage::list(conn.clone(), epoch, vec![vec![60; 20]], profiler.as_ref()) .await .unwrap()[0] .clone(); diff --git a/src/aggregator/report.rs b/src/aggregator/report.rs index 3d1b94b..49c8427 100644 --- a/src/aggregator/report.rs +++ b/src/aggregator/report.rs @@ -7,9 +7,17 @@ use futures::future::{BoxFuture, FutureExt}; use serde_json::Value; use std::collections::HashMap; use std::str::from_utf8; -use std::sync::Arc; use std::time::Instant; +pub struct MeasurementReporter<'a> { + epoch_config: &'a EpochConfig, + out_stream: Option<&'a DynRecordStream>, + profiler: &'a Profiler, + epoch: u8, + epoch_start_date: String, + partial_report: bool, +} + fn build_full_measurement_json( metric_chain: Vec<(String, Value)>, epoch_date_field_name: &str, @@ -28,113 +36,102 @@ fn build_full_measurement_json( Ok(serde_json::to_vec(&full_measurement)?) } -fn report_measurements_recursive<'a>( - rec_msgs: &'a mut RecoveredMessages, - epoch: u8, - epoch_date_field_name: &'a str, - epoch_start_date: &'a str, - partial_report: bool, - out_stream: Option<&'a DynRecordStream>, - metric_chain: Vec<(String, Value)>, - parent_msg_tag: Option>, - profiler: Arc, -) -> BoxFuture<'a, Result> { - async move { - let tags = rec_msgs.get_tags_by_parent(epoch, parent_msg_tag); +impl<'a> MeasurementReporter<'a> { + pub fn new( + epoch_config: &'a EpochConfig, + out_stream: Option<&'a DynRecordStream>, + profiler: &'a Profiler, + epoch: u8, + partial_report: bool, + ) -> Self { + let epoch_start_date = get_epoch_survey_date(&epoch_config, epoch); + Self { + epoch_config, + out_stream, + profiler, + epoch, + epoch_start_date, + partial_report, + } + } - let mut recovered_count = 0; + fn report_recursive( + &'a self, + rec_msgs: &'a mut RecoveredMessages, + metric_chain: Vec<(String, Value)>, + parent_msg_tag: Option>, + ) -> BoxFuture<'a, Result> { + async move { + let tags = rec_msgs.get_tags_by_parent(self.epoch, parent_msg_tag); - for tag in tags { - let mut msg = rec_msgs.get_mut(epoch, &tag).unwrap().clone(); - if msg.count == 0 { - continue; - } + let mut recovered_count = 0; - let mut metric_chain = metric_chain.clone(); - metric_chain.push((msg.metric_name.clone(), msg.metric_value.clone().into())); + for tag in tags { + let mut msg = rec_msgs.get_mut(self.epoch, &tag).unwrap().clone(); + if msg.count == 0 { + continue; + } - // is_msmt_final: true if the current measurement should be reported right now - // i.e. all layers have been recovered - let is_msmt_final = if msg.has_children { - let children_rec_count = report_measurements_recursive( - rec_msgs, - epoch, - epoch_date_field_name, - epoch_start_date, - partial_report, - out_stream, - metric_chain.clone(), - Some(tag), - profiler.clone(), - ) - .await?; + let mut metric_chain = metric_chain.clone(); + metric_chain.push((msg.metric_name.clone(), msg.metric_value.clone().into())); - msg.count -= children_rec_count; + // is_msmt_final: true if the current measurement should be reported right now + // i.e. all layers have been recovered, or a partial report was requested for an old epoch + let is_msmt_final = if msg.has_children { + let children_rec_count = self + .report_recursive(rec_msgs, metric_chain.clone(), Some(tag)) + .await?; - if msg.count > 0 && partial_report { - // partial_report is typically true during an expired epoch report. - // If the count for the current tag is non-zero, and child tags cannot be recovered, - // report the partial measurements now. - true - } else { - recovered_count += children_rec_count; - false - } - } else { - // if there are no children, we have recovered all tags in the metric chain; - // we can safely report the final measurement - true - }; + msg.count -= children_rec_count; - if is_msmt_final { - recovered_count += msg.count; - let full_msmt = build_full_measurement_json( - metric_chain, - epoch_date_field_name, - epoch_start_date, - msg.count, - )?; - let start_instant = Instant::now(); - match out_stream { - Some(o) => o.queue_produce(full_msmt).await?, - None => println!("{}", from_utf8(&full_msmt)?), + if msg.count > 0 && self.partial_report { + // partial_report is typically true during an expired epoch report. + // If the count for the current tag is non-zero, and child tags cannot be recovered, + // report the partial measurements now. + true + } else { + recovered_count += children_rec_count; + false + } + } else { + // if there are no children, we have recovered all tags in the metric chain; + // we can safely report the final measurement + true }; - profiler - .record_range_time(ProfilerStat::OutStreamProduceTime, start_instant) - .await; - msg.count = 0; + + if is_msmt_final { + recovered_count += msg.count; + let full_msmt = build_full_measurement_json( + metric_chain, + &self.epoch_config.epoch_date_field_name, + &self.epoch_start_date, + msg.count, + )?; + let start_instant = Instant::now(); + match self.out_stream { + Some(o) => o.queue_produce(full_msmt).await?, + None => println!("{}", from_utf8(&full_msmt)?), + }; + self + .profiler + .record_range_time(ProfilerStat::OutStreamProduceTime, start_instant) + .await; + msg.count = 0; + } + rec_msgs.add(msg); } - rec_msgs.add(msg); - } - Ok(recovered_count) + Ok(recovered_count) + } + .boxed() } - .boxed() -} -pub async fn report_measurements( - rec_msgs: &mut RecoveredMessages, - epoch_config: &EpochConfig, - epoch: u8, - partial_report: bool, - out_stream: Option<&DynRecordStream>, - profiler: Arc, -) -> Result { - let epoch_start_date = get_epoch_survey_date(&epoch_config, epoch); - Ok( - report_measurements_recursive( - rec_msgs, - epoch, - &epoch_config.epoch_date_field_name, - &epoch_start_date, - partial_report, - out_stream, - Vec::new(), - None, - profiler, - ) - .await?, - ) + pub async fn report( + &'a self, + rec_msgs: &'a mut RecoveredMessages, + ) -> Result { + self.report_recursive(rec_msgs, Vec::new(), None).await + } } #[cfg(test)] @@ -164,7 +161,7 @@ mod tests { async fn full_report() { let record_stream = TestRecordStream::default(); let mut recovered_msgs = RecoveredMessages::default(); - let profiler = Arc::new(Profiler::default()); + let profiler = Profiler::default(); let new_rec_msgs = vec![ RecoveredMessage { @@ -227,16 +224,10 @@ mod tests { for rec_msg in new_rec_msgs { recovered_msgs.add(rec_msg); } - let rec_count = report_measurements( - &mut recovered_msgs, - &test_epoch_config(2), - 2, - false, - Some(&record_stream), - profiler, - ) - .await - .unwrap(); + let epoch_config = test_epoch_config(2); + let reporter = + MeasurementReporter::new(&epoch_config, Some(&record_stream), &profiler, 2, false); + let rec_count = reporter.report(&mut recovered_msgs).await.unwrap(); assert_eq!(rec_count, 17); let records = parse_and_sort_records(record_stream.records_produced.into_inner()); @@ -265,7 +256,7 @@ mod tests { async fn partial_report() { let record_stream = TestRecordStream::default(); let mut recovered_msgs = RecoveredMessages::default(); - let profiler = Arc::new(Profiler::default()); + let profiler = Profiler::default(); let new_rec_msgs = vec![ RecoveredMessage { @@ -328,16 +319,10 @@ mod tests { for rec_msg in new_rec_msgs { recovered_msgs.add(rec_msg); } - report_measurements( - &mut recovered_msgs, - &test_epoch_config(2), - 2, - true, - Some(&record_stream), - profiler, - ) - .await - .unwrap(); + let epoch_config = test_epoch_config(2); + let reporter = + MeasurementReporter::new(&epoch_config, Some(&record_stream), &profiler, 2, true); + reporter.report(&mut recovered_msgs).await.unwrap(); let records = parse_and_sort_records(record_stream.records_produced.into_inner()); diff --git a/src/models/mod.rs b/src/models/mod.rs index a168fc5..386139e 100644 --- a/src/models/mod.rs +++ b/src/models/mod.rs @@ -144,7 +144,7 @@ pub trait BatchInsert { async fn insert_batch( self, conn: Arc>, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError>; } diff --git a/src/models/pending_msg.rs b/src/models/pending_msg.rs index 8cfdfda..1e1baa4 100644 --- a/src/models/pending_msg.rs +++ b/src/models/pending_msg.rs @@ -30,7 +30,7 @@ impl PendingMessage { conn: Arc>, filter_epoch_tag: i16, filter_msg_tag: Vec, - profiler: Arc, + profiler: &Profiler, ) -> Result, PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -53,7 +53,7 @@ impl PendingMessage { pub async fn delete_epoch( conn: Arc>, filter_epoch_tag: i16, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -74,7 +74,7 @@ impl PendingMessage { conn: Arc>, filter_epoch_tag: i16, filter_msg_tag: Vec, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -101,7 +101,7 @@ impl BatchInsert for Vec { async fn insert_batch( self, conn: Arc>, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { diff --git a/src/models/recovered_msg.rs b/src/models/recovered_msg.rs index fc7142e..ae29a96 100644 --- a/src/models/recovered_msg.rs +++ b/src/models/recovered_msg.rs @@ -55,7 +55,7 @@ impl RecoveredMessage { conn: Arc>, filter_epoch_tag: i16, filter_msg_tags: Vec>, - profiler: Arc, + profiler: &Profiler, ) -> Result, PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -80,7 +80,7 @@ impl RecoveredMessage { conn: Arc>, curr_id: i64, new_count: i64, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -103,7 +103,7 @@ impl RecoveredMessage { pub async fn list_with_nonzero_count( conn: Arc>, filter_epoch_tag: i16, - profiler: Arc, + profiler: &Profiler, ) -> Result, PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -143,7 +143,7 @@ impl RecoveredMessage { pub async fn delete_epoch( conn: Arc>, filter_epoch_tag: i16, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || { @@ -167,7 +167,7 @@ impl BatchInsert for Vec { async fn insert_batch( self, conn: Arc>, - profiler: Arc, + profiler: &Profiler, ) -> Result<(), PgStoreError> { let start_instant = Instant::now(); let result = task::spawn_blocking(move || {