Skip to content

Commit

Permalink
feat: Add timeout options and stats to Badger (#496)
Browse files Browse the repository at this point in the history
I'm adding a couple of statistics and a new timeout option to Badger.
I'd like to use these in a paper I am writing; having them upstreamed
would make reproducing these results easier.

That being said, if you are against tracking these in every future run
of Badger, I am happy to close this without merging.
  • Loading branch information
lmondada authored Jul 24, 2024
1 parent 4052714 commit 32a9885
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 3 deletions.
9 changes: 9 additions & 0 deletions badger-optimiser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ struct CmdLineArgs {
help = "Maximum time in seconds to wait between circuit improvements (default=None)."
)]
progress_timeout: Option<u64>,
/// Maximum number of circuits to process (default=no limit)
#[arg(
short = 'c',
long,
value_name = "MAX_CIRCUIT_CNT",
help = "Maximum number of circuits to process (default=None)."
)]
max_circuit_cnt: Option<usize>,
/// Number of threads (default=1)
#[arg(
short = 'j',
Expand Down Expand Up @@ -168,6 +176,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
n_threads,
split_circuit: opts.split_circ,
queue_size: opts.queue_size,
max_circuit_cnt: opts.max_circuit_cnt,
},
);

Expand Down
12 changes: 12 additions & 0 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ impl PyBadgerOptimiser {
/// If `None` the optimiser will run indefinitely, or until `timeout` is
/// reached.
///
/// * `max_circuit_cnt`: The maximum number of circuits to process before
/// stopping the optimisation.
///
/// For data parallel multi-threading, (split_circuit=true), applies on
/// a per-thread basis, otherwise applies globally.
///
/// If `None` the optimiser will run indefinitely, or until `timeout` is
/// reached.
///
/// * `n_threads`: The number of threads to use. Defaults to `1`.
///
/// * `split_circ`: Whether to split the circuit into chunks and process
Expand All @@ -84,6 +94,7 @@ impl PyBadgerOptimiser {
circ: &Bound<'py, PyAny>,
timeout: Option<u64>,
progress_timeout: Option<u64>,
max_circuit_cnt: Option<usize>,
n_threads: Option<NonZeroUsize>,
split_circ: Option<bool>,
queue_size: Option<usize>,
Expand All @@ -92,6 +103,7 @@ impl PyBadgerOptimiser {
let options = BadgerOptions {
timeout,
progress_timeout,
max_circuit_cnt,
n_threads: n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()),
split_circuit: split_circ.unwrap_or(false),
queue_size: queue_size.unwrap_or(100),
Expand Down
36 changes: 35 additions & 1 deletion tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ pub struct BadgerOptions {
///
/// Defaults to `None`, which means no timeout.
pub progress_timeout: Option<u64>,
/// The maximum number of circuits to process before stopping the optimisation.
///
/// For data parallel multi-threading, (split_circuit=true), applies on a
/// per-thread basis, otherwise applies globally.
///
/// Defaults to `None`, which means no limit.
pub max_circuit_cnt: Option<usize>,
/// The number of threads to use.
///
/// Defaults to `1`.
Expand Down Expand Up @@ -79,6 +86,7 @@ impl Default for BadgerOptions {
n_threads: NonZeroUsize::new(1).unwrap(),
split_circuit: Default::default(),
queue_size: 20,
max_circuit_cnt: None,
}
}
}
Expand Down Expand Up @@ -201,6 +209,7 @@ where
circ_cnt += 1;

let rewrites = self.rewriter.get_rewrites(&circ);
logger.register_branching_factor(rewrites.len());

// Get combinations of rewrites that can be applied to the circuit,
// and filter them to keep only the ones that
Expand Down Expand Up @@ -242,6 +251,12 @@ where
break;
}
}
if let Some(max_circuit_cnt) = opt.max_circuit_cnt {
if seen_hashes.len() >= max_circuit_cnt {
timeout_flag = true;
break;
}
}
}

logger.log_processing_end(
Expand All @@ -250,6 +265,7 @@ where
best_circ_cost,
false,
timeout_flag,
start_time.elapsed(),
);
best_circ
}
Expand All @@ -265,6 +281,7 @@ where
mut logger: BadgerLogger,
opt: BadgerOptions,
) -> Circuit {
let start_time = Instant::now();
let n_threads: usize = opt.n_threads.get();
let circ = circ.to_owned();

Expand Down Expand Up @@ -330,6 +347,14 @@ where
Ok(PriorityChannelLog::CircuitCount{processed_count: proc, seen_count: seen, queue_length}) => {
processed_count = proc;
seen_count = seen;
if let Some(max_circuit_cnt) = opt.max_circuit_cnt {
if seen_count > max_circuit_cnt {
timeout_flag = true;
// Signal the workers to stop.
let _ = pq.close();
break;
}
}
logger.log_progress(processed_count, Some(queue_length), seen_count);
}
Err(crossbeam_channel::RecvError) => {
Expand Down Expand Up @@ -382,6 +407,7 @@ where
best_circ_cost,
true,
timeout_flag,
start_time.elapsed(),
);

joins.into_iter().for_each(|j| j.join().unwrap());
Expand All @@ -399,6 +425,7 @@ where
mut logger: BadgerLogger,
opt: BadgerOptions,
) -> Result<Circuit, HugrError> {
let start_time = Instant::now();
let circ = circ.to_owned();
let circ_cost = self.cost(&circ);
let max_chunk_cost = circ_cost.clone().div_cost(opt.n_threads);
Expand Down Expand Up @@ -453,7 +480,14 @@ where
logger.log_best(best_circ_cost.clone(), num_rewrites);
}

logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false);
logger.log_processing_end(
opt.n_threads.get(),
None,
best_circ_cost,
true,
false,
start_time.elapsed(),
);
joins.into_iter().for_each(|j| j.join().unwrap());

Ok(best_circ)
Expand Down
56 changes: 54 additions & 2 deletions tket2/src/optimiser/badger/log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct BadgerLogger<'w> {
circ_candidates_csv: Option<csv::Writer<Box<dyn io::Write + 'w>>>,
last_circ_processed: usize,
last_progress_time: Instant,
branching_factor: UsizeAverage,
}

impl<'w> Default for BadgerLogger<'w> {
Expand All @@ -17,6 +18,7 @@ impl<'w> Default for BadgerLogger<'w> {
last_circ_processed: Default::default(),
// Ensure the first progress message is printed.
last_progress_time: Instant::now() - Duration::from_secs(60),
branching_factor: UsizeAverage::new(),
}
}
}
Expand Down Expand Up @@ -75,17 +77,22 @@ impl<'w> BadgerLogger<'w> {
best_cost: C,
needs_joining: bool,
timeout: bool,
elapsed_time: Duration,
) {
let elapsed_secs = elapsed_time.as_secs_f32();
match timeout {
true => self.log("Optimisation finished (timeout)."),
false => self.log("Optimisation finished."),
true => self.log(format!(
"Optimisation finished in {elapsed_secs:.2}s (timeout)."
)),
false => self.log(format!("Optimisation finished in {elapsed_secs:.2}s.")),
};
match circuits_seen {
Some(circuits_seen) => self.log(format!(
"Processed {circuits_processed} circuits (out of {circuits_seen} seen)."
)),
None => self.log(format!("Processed {circuits_processed} circuits.")),
}
self.log_avg_branching_factor();
self.log(format!("---- END RESULT: {:?} ----", best_cost));
if needs_joining {
self.log("Joining worker threads.");
Expand Down Expand Up @@ -120,11 +127,29 @@ impl<'w> BadgerLogger<'w> {
tracing::info!(target: LOG_TARGET, "{}", msg.as_ref());
}

/// Emit `msg` as a warning-level `tracing` event on `LOG_TARGET`.
///
/// Accepts anything that can be viewed as a `&str`, so both string
/// literals and owned `String`s can be passed directly.
#[inline]
pub fn warn(&self, msg: impl AsRef<str>) {
    let text: &str = msg.as_ref();
    tracing::warn!(target: LOG_TARGET, "{}", text);
}

/// Emit `msg` as an info-level `tracing` event on `PROGRESS_TARGET`.
///
/// Intended for verbose, frequently repeated updates on how the
/// optimisation is advancing.
#[inline]
pub fn progress(&self, msg: impl AsRef<str>) {
    let text: &str = msg.as_ref();
    tracing::info!(target: PROGRESS_TARGET, "{}", text);
}

/// Record one more branching-factor sample in the running average kept by
/// this logger.
pub fn register_branching_factor(&mut self, branching_factor: usize) {
    let running_average = &mut self.branching_factor;
    running_average.append(branching_factor);
}

/// Log the mean of all branching factors registered so far.
///
/// Nothing is logged when no branching factor has been registered yet.
pub fn log_avg_branching_factor(&self) {
    let average = self.branching_factor.average();
    if let Some(mean) = average {
        let message = format!("Average branching factor: {}", mean);
        self.log(message);
    }
}
}

/// A helper struct for logging improvements in circuit size seen during the
Expand All @@ -143,3 +168,30 @@ impl<C> BestCircSer<C> {
Self { circ_cost, time }
}
}

/// Running arithmetic mean over a stream of `usize` samples.
///
/// The sum is accumulated in a `u128` and the sample count in a `u64`, so
/// that a long run with many large samples cannot overflow the accumulator
/// (which would panic in debug builds and silently wrap in release builds,
/// notably on targets where `usize` is 32 bits wide).
#[derive(Debug, Default)]
struct UsizeAverage {
    /// Sum of every sample appended so far.
    sum: u128,
    /// Number of samples appended so far.
    count: u64,
}

impl UsizeAverage {
    /// Create an empty average with no samples registered.
    pub fn new() -> Self {
        Self::default()
    }

    /// Add `value` as a new sample.
    pub fn append(&mut self, value: usize) {
        // `usize` -> `u128` is a lossless widening on every supported target.
        self.sum += value as u128;
        self.count += 1;
    }

    /// Return the mean of all samples, or `None` if no sample was appended.
    ///
    /// The result is computed in `f64`, so it is rounded to the nearest
    /// representable value for very large sums.
    pub fn average(&self) -> Option<f64> {
        if self.count == 0 {
            return None;
        }
        Some(self.sum as f64 / self.count as f64)
    }
}

0 comments on commit 32a9885

Please sign in to comment.