diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index 0c9b56af..399a5d7f 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -81,6 +81,14 @@ struct CmdLineArgs { help = "Maximum time in seconds to wait between circuit improvements (default=None)." )] progress_timeout: Option, + /// Maximum number of circuits to process (default=no limit) + #[arg( + short = 'c', + long, + value_name = "MAX_CIRCUIT_CNT", + help = "Maximum number of circuits to process (default=None)." + )] + max_circuit_cnt: Option, /// Number of threads (default=1) #[arg( short = 'j', @@ -168,6 +176,7 @@ fn main() -> Result<(), Box> { n_threads, split_circuit: opts.split_circ, queue_size: opts.queue_size, + max_circuit_cnt: opts.max_circuit_cnt, }, ); diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index a311780c..e5bb6cde 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -61,6 +61,16 @@ impl PyBadgerOptimiser { /// If `None` the optimiser will run indefinitely, or until `timeout` is /// reached. /// + /// * `max_circuit_cnt`: The maximum number of circuits to process before + /// stopping the optimisation. + /// + /// + /// For data parallel multi-threading, (split_circuit=true), applies on + /// a per-thread basis, otherwise applies globally. + /// + /// If `None` the optimiser will run indefinitely, or until `timeout` is + /// reached. + /// /// * `n_threads`: The number of threads to use. Defaults to `1`. /// /// * `split_circ`: Whether to split the circuit into chunks and process @@ -84,6 +94,7 @@ impl PyBadgerOptimiser { circ: &Bound<'py, PyAny>, timeout: Option, progress_timeout: Option, + max_circuit_cnt: Option, n_threads: Option, split_circ: Option, queue_size: Option, @@ -92,6 +103,7 @@ impl PyBadgerOptimiser { let options = BadgerOptions { timeout, progress_timeout, + max_circuit_cnt, n_threads: n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), split_circuit: split_circ.unwrap_or(false), queue_size: queue_size.unwrap_or(100), diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index 3daeb68f..4359166f 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -51,6 +51,13 @@ pub struct BadgerOptions { /// /// Defaults to `None`, which means no timeout. pub progress_timeout: Option, + /// The maximum number of circuits to process before stopping the optimisation. + /// + /// For data parallel multi-threading, (split_circuit=true), applies on a + /// per-thread basis, otherwise applies globally. + /// + /// Defaults to `None`, which means no limit. + pub max_circuit_cnt: Option, /// The number of threads to use. /// /// Defaults to `1`. @@ -79,6 +86,7 @@ impl Default for BadgerOptions { n_threads: NonZeroUsize::new(1).unwrap(), split_circuit: Default::default(), queue_size: 20, + max_circuit_cnt: None, } } } @@ -201,6 +209,7 @@ where circ_cnt += 1; let rewrites = self.rewriter.get_rewrites(&circ); + logger.register_branching_factor(rewrites.len()); // Get combinations of rewrites that can be applied to the circuit, // and filter them to keep only the ones that @@ -242,6 +251,12 @@ where break; } } + if let Some(max_circuit_cnt) = opt.max_circuit_cnt { + if seen_hashes.len() >= max_circuit_cnt { + timeout_flag = true; + break; + } + } } logger.log_processing_end( @@ -250,6 +265,7 @@ where best_circ_cost, false, timeout_flag, + start_time.elapsed(), ); best_circ } @@ -265,6 +281,7 @@ where mut logger: BadgerLogger, opt: BadgerOptions, ) -> Circuit { + let start_time = Instant::now(); let n_threads: usize = opt.n_threads.get(); let circ = circ.to_owned(); @@ -330,6 +347,14 @@ where Ok(PriorityChannelLog::CircuitCount{processed_count: proc, seen_count: seen, queue_length}) => { processed_count = proc; seen_count = seen; + if let Some(max_circuit_cnt) = opt.max_circuit_cnt { + if seen_count > max_circuit_cnt { + timeout_flag = true; + // Signal the workers to stop. + let _ = pq.close(); + break; + } + } logger.log_progress(processed_count, Some(queue_length), seen_count); } Err(crossbeam_channel::RecvError) => { @@ -382,6 +407,7 @@ where best_circ_cost, true, timeout_flag, + start_time.elapsed(), ); joins.into_iter().for_each(|j| j.join().unwrap()); @@ -399,6 +425,7 @@ where mut logger: BadgerLogger, opt: BadgerOptions, ) -> Result { + let start_time = Instant::now(); let circ = circ.to_owned(); let circ_cost = self.cost(&circ); let max_chunk_cost = circ_cost.clone().div_cost(opt.n_threads); @@ -453,7 +480,14 @@ where logger.log_best(best_circ_cost.clone(), num_rewrites); } - logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false); + logger.log_processing_end( + opt.n_threads.get(), + None, + best_circ_cost, + true, + false, + start_time.elapsed(), + ); joins.into_iter().for_each(|j| j.join().unwrap()); Ok(best_circ) diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index e7201d66..116b7ad4 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs @@ -8,6 +8,7 @@ pub struct BadgerLogger<'w> { circ_candidates_csv: Option>>, last_circ_processed: usize, last_progress_time: Instant, + branching_factor: UsizeAverage, } impl<'w> Default for BadgerLogger<'w> { @@ -17,6 +18,7 @@ impl<'w> Default for BadgerLogger<'w> { last_circ_processed: Default::default(), // Ensure the first progress message is printed. last_progress_time: Instant::now() - Duration::from_secs(60), + branching_factor: UsizeAverage::new(), } } } @@ -75,10 +77,14 @@ impl<'w> BadgerLogger<'w> { best_cost: C, needs_joining: bool, timeout: bool, + elapsed_time: Duration, ) { + let elapsed_secs = elapsed_time.as_secs_f32(); match timeout { - true => self.log("Optimisation finished (timeout)."), - false => self.log("Optimisation finished."), + true => self.log(format!( + "Optimisation finished in {elapsed_secs:.2}s (timeout)." + )), + false => self.log(format!("Optimisation finished in {elapsed_secs:.2}s.")), }; match circuits_seen { Some(circuits_seen) => self.log(format!( @@ -86,6 +92,7 @@ impl<'w> BadgerLogger<'w> { )), None => self.log(format!("Processed {circuits_processed} circuits.")), } + self.log_avg_branching_factor(); self.log(format!("---- END RESULT: {:?} ----", best_cost)); if needs_joining { self.log("Joining worker threads."); @@ -120,11 +127,29 @@ impl<'w> BadgerLogger<'w> { tracing::info!(target: LOG_TARGET, "{}", msg.as_ref()); } + /// Log a warning message. + #[inline] + pub fn warn(&self, msg: impl AsRef) { + tracing::warn!(target: LOG_TARGET, "{}", msg.as_ref()); + } + /// Log verbose information on the progress of the optimization. #[inline] pub fn progress(&self, msg: impl AsRef) { tracing::info!(target: PROGRESS_TARGET, "{}", msg.as_ref()); } + + /// Append a new branching factor to the average. + pub fn register_branching_factor(&mut self, branching_factor: usize) { + self.branching_factor.append(branching_factor); + } + + /// Log the average branching factor so far. + pub fn log_avg_branching_factor(&self) { + if let Some(avg) = self.branching_factor.average() { + self.log(format!("Average branching factor: {}", avg)); + } + } } /// A helper struct for logging improvements in circuit size seen during the @@ -143,3 +168,30 @@ impl BestCircSer { Self { circ_cost, time } } } + +struct UsizeAverage { + sum: usize, + count: usize, +} + +impl UsizeAverage { + pub fn new() -> Self { + Self { + sum: Default::default(), + count: 0, + } + } + + pub fn append(&mut self, value: usize) { + self.sum += value; + self.count += 1; + } + + pub fn average(&self) -> Option { + if self.count > 0 { + Some(self.sum as f64 / self.count as f64) + } else { + None + } + } +}