diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1a387b77..61ac7736 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,7 +8,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: setup - run: rustup default 1.67 && rustup component add clippy + run: rustup default 1.72 && rustup component add clippy - name: lint run: cargo clippy --version && cargo clippy --all-targets --all-features --tests --no-deps -- -D warnings - name: build @@ -18,7 +18,7 @@ jobs: steps: - uses: actions/checkout@v3 - name: setup - run: rustup default 1.67 + run: rustup default 1.72 - name: test run: cargo --version && cargo test --release format: @@ -26,6 +26,6 @@ jobs: steps: - uses: actions/checkout@v3 - name: setup - run: rustup default 1.67 && rustup component add rustfmt + run: rustup default 1.72 && rustup component add rustfmt - name: check formatting run: cargo fmt --version && cargo fmt --check diff --git a/Cargo.lock b/Cargo.lock index 938f6a3a..cf629d78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,18 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "ahash" -version = "0.8.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" -dependencies = [ - "cfg-if", - "getrandom", - "once_cell", - "version_check", -] - [[package]] name = "ansi_term" version = "0.12.1" @@ -54,9 +42,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "carcara" -version = "1.0.0" +version = "1.1.0" dependencies = [ - "ahash", + "indexmap 2.0.0", "log", "rand", "rug", @@ -66,11 +54,9 @@ dependencies = [ [[package]] name = "carcara-cli" -version = "1.0.0" +version = "1.1.0" dependencies = [ - "ahash", "ansi_term", - "atty", "carcara", "clap", "const_format", @@ -87,15 +73,15 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "3.2.23" +version = "3.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5" +checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" dependencies = [ "atty", "bitflags", "clap_derive", "clap_lex", - "indexmap", + "indexmap 1.9.3", "once_cell", "strsim", "termcolor", @@ -104,9 +90,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "3.2.18" +version = "3.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea0c8bce528c4be4da13ea6fead8965e95b6073585a2f05204bd8f4119f82a65" +checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" dependencies = [ "heck", "proc-macro-error", @@ -126,18 +112,18 @@ dependencies = [ [[package]] name = "const_format" -version = "0.2.30" +version = "0.2.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7309d9b4d3d2c0641e018d449232f2e28f1b22933c137f157d3dbc14228b8c0e" +checksum = "c990efc7a285731f9a4378d81aff2f0e85a2c8781a05ef0f8baa8dac54d0ff48" dependencies = [ "const_format_proc_macros", ] [[package]] name = "const_format_proc_macros" -version = "0.2.29" +version = "0.2.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d897f47bf7270cf70d370f8f98c1abb6d2d4cf60a6845d30e05bfb90c6568650" +checksum = "e026b6ce194a874cb9cf32cd5772d1ef9767cc8fcb5765948d74f37a9d8b2bf6" dependencies = [ "proc-macro2", "quote", @@ -156,18 +142,24 @@ dependencies = [ [[package]] name = 
"crossbeam-utils" -version = "0.8.15" +version = "0.8.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "getrandom" -version = "0.2.9" +version = "0.2.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" dependencies = [ "cfg-if", "libc", @@ -198,9 +190,9 @@ dependencies = [ [[package]] name = "gmp-mpfr-sys" -version = "1.5.2" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b560063e2ffa8ce9c2ef9bf487f2944a97deca5b8de0b5bcd0ae6437ef8b75f" +checksum = "19c5c67d8c29fe87e3266e691dd60948e6e4df4496c53355ef3551142945721b" dependencies = [ "libc", "windows-sys", @@ -212,6 +204,12 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +[[package]] +name = "hashbrown" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" + [[package]] name = "heck" version = "0.4.1" @@ -234,35 +232,42 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", +] + +[[package]] +name = "indexmap" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +dependencies = [ + "equivalent", + "hashbrown 0.14.0", ] [[package]] name = "libc" -version = "0.2.142" +version = "0.2.147" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a987beff54b60ffa6d51982e1aa1146bc42f19bd26be28b0586f252fccf5317" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "log" -version = "0.4.17" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e" -dependencies = [ - "cfg-if", -] +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "once_cell" -version = "1.17.1" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "os_str_bytes" -version = "6.5.0" +version = "6.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" +checksum = "4d5d9eb14b174ee9aa2ef96dc2b94637a2d4b6e7cb873c7e171f0c20c6cf3eac" [[package]] name = "ppv-lite86" @@ -302,18 +307,18 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.56" +version = "1.0.66" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.26" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] @@ -350,9 +355,9 @@ dependencies = [ [[package]] name = "rug" -version = "1.19.2" +version = "1.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "555e8b44763d034526db899c88cd56ccc4486cd38b444c8aa0e79d4e70ae5a34" +checksum = "8882d6fd62b334b72dcf5c79f7e6b529d6790322de14bb49339415266131b031" dependencies = [ "az", "gmp-mpfr-sys", @@ -387,9 +392,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.15" +version = "2.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a34fcf3e8b60f57e6a14301a2e916d323af98b0ea63c599441eec8558660c822" +checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" dependencies = [ "proc-macro2", "quote", @@ -423,29 +428,29 @@ checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" [[package]] name = "thiserror" -version = "1.0.40" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.40" +version = "1.0.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.15", + "syn 2.0.29", ] [[package]] name = "unicode-ident" -version = "1.0.8" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5464a87b239f13a63a501f2701565754bae92d243d4bb7eb12f6d57d2269bf4" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" [[package]] name = "unicode-xid" diff --git a/Cargo.toml b/Cargo.toml index 1024221f..867dbc0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] members = ["carcara", "cli", "test-generator"] +resolver = "2" [profile.release] debug = 1 diff --git a/README.md b/README.md index 3f42b040..6b599b51 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Carcara is a proof checker and elaborator for SMT proofs in the [Alethe format]( ## Building -To build Carcara, you will need Rust and Cargo 1.67 or newer. Build the project with `cargo build`. +To build Carcara, you will need Rust and Cargo 1.72 or newer. Build the project with `cargo build`. When running on large proofs, we recommend compiling with optimizations enabled: `cargo build --release`. @@ -17,12 +17,12 @@ the project with all optimizations enabled, and install the CLI binary in `$HOME To check a proof file, use the `check` command, passing both the proof file and the original SMT-LIB problem file. 
``` -carcara check example.smt2.proof example.smt2 +carcara check example.smt2.alethe example.smt2 ``` -If the problem file name is exactly the proof file name minus `.proof`, you can omit it: +If the problem file name is exactly the proof file name minus `.alethe`, you can omit it: ``` -carcara check example.smt2.proof +carcara check example.smt2.alethe ``` By default, Carcara will return a checking error when encountering a rule it does not recognize. If @@ -37,31 +37,61 @@ See `carcara help check` for more options. You can elaborate a proof file using the `elaborate` command. ``` -carcara elaborate example.smt2.proof example.smt2 +carcara elaborate example.smt2.alethe example.smt2 ``` This command will check the given proof while elaborating it, and print the elaborated proof to standard output. The `--print-with-sharing` flag controls whether the elaborated proof will be printed using term sharing. -By default, elaboration of `lia_generic` steps using cvc5 is disabled. To enable it, pass the -`--lia-via-cvc5` flag. You will need to have a working binary of cvc5 in your PATH. - Many of the same flags used in the `check` command also apply to the `elaborate` command. See `carcara help elaborate` for more details. +### `lia_generic` steps + +By default, Carcara ignores steps of the `lia_generic` rule when checking or elaborating a proof, +instead considering them as holes. However, you can use an external solver to aid Carcara in +checking these steps, using the `--lia-solver` option. For example, running +``` +carcara check example.smt2.alethe --lia-solver cvc5 +``` + +will check the proof, using cvc5 (more precisely, the cvc5 binary in your `PATH`) to check any +`lia_generic` steps. This is done by converting the `lia_generic` step into an SMT-LIB problem, +giving it to the solver, and checking the Alethe proof that the solver produces. If we were +elaborating the proof instead of just checking it, the solver proof would also be inserted in place +of the `lia_generic` step. +
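+For illustration, suppose the proof contains a (hypothetical) step such as
+```
+(step t1 (cl (not (<= x 0)) (<= (* 2 x) 0)) :rule lia_generic)
+```
+Checking this step amounts to asking the solver whether the negation of its clause is
+unsatisfiable, so the generated SMT-LIB problem would look roughly like this:
+```
+(set-logic QF_LIA)
+(declare-const x Int)
+(assert (<= x 0))
+(assert (not (<= (* 2 x) 0)))
+(check-sat)
+```
+The solver is expected to answer `unsat` and to justify that answer with an Alethe proof, which
+Carcara then checks in place of the original step.
+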
+The value given to `--lia-solver` should be the path of the solver binary. Conceivably, any solver +can be used (SMT or otherwise) as long as it is able to read SMT-LIB from stdin, solve the linear +integer arithmetic problem, and output an Alethe proof to stdout. + +The `--lia-solver-args` option can be used to change the arguments passed to the solver binary. This +option should receive a single value, where multiple arguments are separated by spaces. For example, +if you wanted to instead check `lia_generic` steps using veriT, you might pass the following +arguments: +``` +carcara check example.smt2.alethe --lia-solver veriT --lia-solver-args "--proof=- --proof-with-sharing" +``` + +The default arguments for `--lia-solver-args` are as follows (note that they assume you use cvc5 as +a solver): +``` +--tlimit=10000 --lang=smt2 --proof-format-mode=alethe --proof-granularity=theory-rewrite --proof-alethe-res-pivots +``` + ### Running benchmarks The `bench` command is used to run benchmarks. For example, the following command will run a benchmark on three proof files. ``` -carcara bench a.smt2.proof b.smt2.proof c.smt2.proof +carcara bench a.smt2.alethe b.smt2.alethe c.smt2.alethe ``` The command takes as arguments any number of proof files or directories. If a directory is passed, -the benchmark will be run on all `.proof` files in that directory. This command assumes that the +the benchmark will be run on all `.alethe` files in that directory. This command assumes that the problem file associated with each proof is in the same directory as the proof, and that they follow -the pattern `.smt2`/`.smt2.proof`. +the pattern `.smt2`/`.smt2.alethe`. The benchmark will parse and check each file, and record performance data. If you pass the `--elaborate` flag, the proofs will also be elaborated (though the resulting elaborated proof is @@ -76,7 +106,6 @@ enable multiple threads using the `-j`/`--num-threads` option. See `carcara help bench` for more options. - ## "Strict" checking Strict checking mode can be enabled by using the `--strict` flag when checking. Currently, this only diff --git a/carcara/Cargo.toml b/carcara/Cargo.toml index 7af5dd7c..8d84176a 100644 --- a/carcara/Cargo.toml +++ b/carcara/Cargo.toml @@ -1,16 +1,16 @@ [package] name = "carcara" -version = "1.0.0" -authors = ["Bruno Andreotti "] +version = "1.1.0" +authors = ["Bruno Andreotti ", "Vinícius Braga Freire "] edition = "2021" -rust-version = "1.67" +rust-version = "1.72" license = "Apache-2.0" [dependencies] -ahash = "0.8.3" -log = "0.4.17" -rug = { version = "1.19.2", features = ["integer", "rational"] } -thiserror = "1.0.40" +indexmap = "2.0.0" +log = "0.4.20" +rug = { version = "1.21.0", features = ["integer", "rational"] } +thiserror = "1.0.47" [dev-dependencies] test-generator = { path = "../test-generator" } diff --git a/carcara/src/ast/context.rs b/carcara/src/ast/context.rs new file mode 100644 index 00000000..f967a923 --- /dev/null +++ b/carcara/src/ast/context.rs @@ -0,0 +1,285 @@ +use crate::ast::*; +use std::sync::{atomic::AtomicUsize, Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; + +pub struct Context { + pub mappings: Vec<(Rc<Term>, Rc<Term>)>, + pub bindings: IndexSet<SortedVar>, + pub cumulative_substitution: Option<Substitution>, +} + +/// A tuple that represents a single `Context` and allows it to be shared between threads. +/// +/// `0`: Number of threads that will use this context. +/// +/// `1`: Shareable and droppable slot for the context. +type ContextInfo = (AtomicUsize, RwLock<Option<Context>>); + +#[derive(Default)] +/// A thread-shared context stack. This stack tries to reuse a global `Context` that was already +/// built by another thread. If it has not been built yet, the current thread builds the context +/// so that other threads can also use it. +pub struct ContextStack { + /// The context vector that is shared globally between all the threads. + /// The storage is index based (the index of each context is given by the + /// anchor/subproof id obtained in the parser). + context_vec: Arc<Vec<ContextInfo>>, + /// The stack of context ids (works just like a map into `context_vec`). + stack: Vec<usize>, + num_cumulative_calculated: usize, +} + +impl ContextStack { + pub fn new() -> Self { + Default::default() + } + + /// Creates an empty stack from context usage info (a vector indicating how + /// many threads are going to use each context). + pub fn from_usage(context_usage: &Vec<usize>) -> Self { + let mut context_vec: Arc<Vec<ContextInfo>> = Arc::new(vec![]); + let ctx_ref = Arc::get_mut(&mut context_vec).unwrap(); + + for &usage in context_usage { + ctx_ref.push((AtomicUsize::new(usage), RwLock::new(None))); + } + + Self { + context_vec, + stack: vec![], + num_cumulative_calculated: 0, + } + } + + /// Creates an empty stack from a previous stack (starting with the context + /// infos already instantiated).
+ pub fn from_previous(&self) -> Self { + Self { + context_vec: self.context_vec.clone(), + stack: vec![], + num_cumulative_calculated: 0, + } + } + + pub fn len(&self) -> usize { + self.stack.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn last(&self) -> Option<RwLockReadGuard<Option<Context>>> { + self.stack + .last() + .map(|id| self.context_vec[*id].1.read().unwrap()) + } + + pub fn last_mut(&mut self) -> Option<RwLockWriteGuard<Option<Context>>> { + self.stack + .last_mut() + .map(|id| self.context_vec[*id].1.write().unwrap()) + } + + /// Forces the creation of a new context at the end of the `context_vec`. + /// This function should be called before `ContextStack::push` when + /// operating in a single thread. Since a single thread doesn't require + /// schedule balancing, there is no info about how many contexts there are + /// in the proof (and none is needed, since we can always add a new context + /// at the end of the vector, just like a usual stack). + pub fn force_new_context(&mut self) -> usize { + let ctx_vec = Arc::get_mut(&mut self.context_vec).unwrap(); + ctx_vec.push((AtomicUsize::new(1), RwLock::new(None))); + ctx_vec.len() - 1 + } + + pub fn push( + &mut self, + pool: &mut dyn TermPool, + assignment_args: &[(String, Rc<Term>)], + variable_args: &[SortedVar], + context_id: usize, + ) -> Result<(), SubstitutionError> { + // The write guard was yielded to this thread + if let Ok(mut ctx_write_guard) = self.context_vec[context_id].1.try_write() { + // It's the first thread trying to build this context. It will + // build this context in the context vector (accessible to all threads) + if ctx_write_guard.is_none() { + // Since some rules (like `refl`) need to apply substitutions until a fixed point, we + // precompute these substitutions into a separate hash map. This assumes that the assignment + // arguments are in the correct order. + let mut substitution = Substitution::empty(); + let mut substitution_until_fixed_point = Substitution::empty(); + + // We build the `substitution_until_fixed_point` hash map from the bottom up, by using the + // substitutions already introduced to transform the result of a new substitution before + // inserting it into the hash map. So for instance, if the substitutions are `(:= y z)` and + // `(:= x (f y))`, we insert the first substitution, and then, when introducing the second, + // we use the current state of the hash map to transform `(f y)` into `(f z)`.
The + // resulting hash map will then contain `(:= y z)` and `(:= x (f z))` + for (var, value) in assignment_args { + let var_term = Term::new_var(var, pool.sort(value)); + let var_term = pool.add(var_term); + substitution.insert(pool, var_term.clone(), value.clone())?; + let new_value = substitution_until_fixed_point.apply(pool, value); + substitution_until_fixed_point.insert(pool, var_term, new_value)?; + } + + let mappings = assignment_args + .iter() + .map(|(var, value)| { + let var_term = (var.clone(), pool.sort(value)).into(); + (pool.add(var_term), value.clone()) + }) + .collect(); + let bindings = variable_args.iter().cloned().collect(); + // Finally, creates the new context under this `RwLock` + *ctx_write_guard = Some(Context { + mappings, + bindings, + cumulative_substitution: None, + }); + } + } + // Adds this context to the stack. + // Notice that even though the context may not be ready for use yet, its + // write guard is still being held by some thread, so if this context is + // required at any moment, we are assured that the requester will wait + // until the context is fully constructed + self.stack.push(context_id); + Ok(()) + } + + pub fn pop(&mut self) { + use std::sync::atomic::Ordering; + + if let Some(id) = self.stack.pop() { + let this_context = &self.context_vec[id]; + + let mut remaining_threads = this_context.0.load(Ordering::Acquire); + remaining_threads = remaining_threads + .checked_sub(1) + .expect("A thread tried to access a context not allocated for it."); + + if remaining_threads == 0 { + // Drop this context since the last thread stopped using it + *this_context.1.write().unwrap() = None; + } + this_context.0.store(remaining_threads, Ordering::Release); + } + + self.num_cumulative_calculated = + std::cmp::min(self.num_cumulative_calculated, self.stack.len()); + } + + fn catch_up_cumulative(&mut self, pool: &mut dyn TermPool, up_to: usize) { + for i in self.num_cumulative_calculated..std::cmp::max(up_to + 1, self.len()) { + // Acquires the read guard. Since the i-th context will only be mutated + // far below this line, we first take the read guard here and later, when + // necessary, acquire the write guard. This tries to avoid larger + // overheads + let context_guard = self.context_vec[self.stack[i]].1.read().unwrap(); + let curr_context = context_guard.as_ref().unwrap(); + + let simultaneous = build_simultaneous_substitution(pool, &curr_context.mappings).map; + let mut cumulative_substitution = simultaneous.clone(); + + if i > 0 { + // Waits until the OS allows reading this previous context. The code structure + // makes sure that this context, when released for reading, will already be + // instantiated, since there are only two cases: + // - This thread was responsible for building this previous context. Then + // this context has already been built. + // - Another thread was assigned to build this context. Then, whether or not + // that thread has already finished the process, the current thread will + // have to wait until the guard is released.
+ if let Some(previous_context) = self + .stack + .get(i - 1) + .map(|id| self.context_vec[*id].1.read().unwrap()) + { + let previous_context = previous_context.as_ref().unwrap(); + let previous_substitution = + previous_context.cumulative_substitution.as_ref().unwrap(); + + for (k, v) in &previous_substitution.map { + let value = match simultaneous.get(v) { + Some(new_value) => new_value, + None => v, + }; + cumulative_substitution.insert(k.clone(), value.clone()); + } + } + } + drop(context_guard); + + // Waits until the OS allows mutating this context + // TODO: Does it really need to acquire a write guard here instead of up there? + let mut context_guard = self.context_vec[self.stack[i]].1.write().unwrap(); + context_guard.as_mut().unwrap().cumulative_substitution = + Some(Substitution::new(pool, cumulative_substitution).unwrap()); + self.num_cumulative_calculated = i + 1; + } + } + + pub fn apply_previous(&mut self, pool: &mut dyn TermPool, term: &Rc<Term>) -> Rc<Term> { + if self.len() < 2 { + term.clone() + } else { + let index = self.len() - 2; + self.catch_up_cumulative(pool, index); + self.context_vec[self.stack[index]] + .1 + .write() + .unwrap() + .as_mut() + .unwrap() + .cumulative_substitution + .as_mut() + .unwrap() + .apply(pool, term) + } + } + + pub fn apply(&mut self, pool: &mut dyn TermPool, term: &Rc<Term>) -> Rc<Term> { + if self.is_empty() { + term.clone() + } else { + let index = self.len() - 1; + self.catch_up_cumulative(pool, index); + self.context_vec[self.stack[index]] + .1 + .write() + .unwrap() + .as_mut() + .unwrap() + .cumulative_substitution + .as_mut() + .unwrap() + .apply(pool, term) + } + } +} + +fn build_simultaneous_substitution( + pool: &mut dyn TermPool, + mappings: &[(Rc<Term>, Rc<Term>)], +) -> Substitution { + let mut result = Substitution::empty(); + + // We build the simultaneous substitution from the bottom up, by using the mappings already + // introduced to transform the result of a new mapping before inserting it into the + // substitution. So for instance, if the mappings are `y -> z` and `x -> (f y)`, we insert the + // first mapping, and then, when introducing the second, we use the current state of the + // substitution to transform `(f y)` into `(f z)`. The result will then contain `y -> z` and + // `x -> (f z)`. + for (var, value) in mappings { + let new_value = result.apply(pool, value); + + // We can unwrap here safely because, by construction, the sort of `var` is the + // same as the sort of `new_value` + result.insert(pool, var.clone(), new_value).unwrap(); + } + result +}
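The bottom-up construction described in the comments above is easy to see in isolation. The following standalone sketch uses toy string-based terms rather than Carcara's actual `Substitution` and `TermPool` types (all names here are illustrative, not the library's API): each new mapping's right-hand side is first rewritten by the mappings inserted so far, which is what makes the final map act as a simultaneous substitution.

```rust
use std::collections::HashMap;

// Rewrites every whitespace-separated token of `term` that has an entry in `map`.
fn apply(map: &HashMap<String, String>, term: &str) -> String {
    term.split_whitespace()
        .map(|w| map.get(w).cloned().unwrap_or_else(|| w.to_string()))
        .collect::<Vec<_>>()
        .join(" ")
}

// Builds the map bottom-up: each right-hand side is first transformed by the
// mappings already inserted, mirroring `build_simultaneous_substitution`.
fn build_simultaneous(mappings: &[(&str, &str)]) -> HashMap<String, String> {
    let mut result = HashMap::new();
    for &(var, value) in mappings {
        let new_value = apply(&result, value);
        result.insert(var.to_string(), new_value);
    }
    result
}

fn main() {
    // With `y -> z` inserted first, the later `x -> ( f y )` becomes `x -> ( f z )`.
    let result = build_simultaneous(&[("y", "z"), ("x", "( f y )")]);
    assert_eq!(result["y"], "z");
    assert_eq!(result["x"], "( f z )");
}
```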
diff --git a/carcara/src/ast/iter.rs b/carcara/src/ast/iter.rs index c9b16847..e3788aab 100644 --- a/carcara/src/ast/iter.rs +++ b/carcara/src/ast/iter.rs @@ -28,7 +28,7 @@ use super::*; /// (step t5 (cl) :rule resolution :premises (t4 h1 h2)) /// " /// .as_bytes(); -/// let (_, proof, _) = parser::parse_instance("".as_bytes(), proof, true, false, false)?; +/// let (_, proof, _) = parser::parse_instance("".as_bytes(), proof, parser::Config::new())?; /// let ids: Vec<_> = proof.iter().map(|c| c.id()).collect(); /// assert_eq!(ids, ["h1", "h2", "t3", "t3.t1", "t3.t2", "t3", "t4", "t5"]); /// # Ok(()) diff --git a/carcara/src/ast/macros.rs b/carcara/src/ast/macros.rs index ffd55e62..42de35cf 100644 --- a/carcara/src/ast/macros.rs +++ b/carcara/src/ast/macros.rs @@ -22,7 +22,7 @@ /// Removing two leading negations from a term: /// ``` /// # use carcara::{ast::*, build_term, match_term}; -/// # let mut pool = TermPool::new(); +/// # let mut pool = PrimitivePool::new(); /// # let t = build_term!(pool, (not (not {pool.bool_false()}))); /// let p = match_term!((not (not p)) = t).unwrap(); /// ``` @@ -31,8 +31,8 @@ /// ``` /// # use carcara::{ast::*, match_term, parser::*}; /// # pub fn parse_term(input: &str) -> Rc<Term> { -/// # let mut pool = TermPool::new(); -/// # let mut parser = Parser::new(&mut pool, input.as_bytes(), true, false, false).unwrap(); +/// # let mut pool = PrimitivePool::new(); +/// # let mut parser = Parser::new(&mut pool, Config::new(), input.as_bytes()).unwrap(); /// # parser.parse_term().unwrap() /// # } /// # let t = parse_term("(and (=> false false) (> (+ 0 0) 0))"); @@ -42,7 +42,7 @@ /// Pattern matching against boolean constants: /// ``` /// # use carcara::{ast::*, build_term, match_term}; -/// # let mut pool = TermPool::new(); +/// # let mut pool = PrimitivePool::new(); /// # let t = build_term!(pool, (or {pool.bool_false()} {pool.bool_false()})); /// let (p, ()) = match_term!((or p false) = t).unwrap(); /// ``` @@ -51,8 +51,8 @@ /// ``` /// # use carcara::{ast::*, match_term, parser::*}; /// # pub fn parse_term(input: &str) -> Rc<Term> { -/// # let mut pool = TermPool::new(); -/// # let mut parser = Parser::new(&mut pool, input.as_bytes(), true, false, false).unwrap(); +/// # let mut pool = PrimitivePool::new(); +/// # let mut parser = Parser::new(&mut pool, Config::new(), input.as_bytes()).unwrap(); /// # parser.parse_term().unwrap() /// # } /// # let t = parse_term("(forall ((x Int) (y Int)) (> x y))"); @@ -62,7 +62,7 @@ /// Pattern matching against a variable number of arguments: /// ``` /// # use carcara::{ast::*, build_term, match_term}; -/// # let mut pool = TermPool::new(); +/// # let mut pool = PrimitivePool::new(); /// # let t = build_term!(pool, (and {pool.bool_false()} {pool.bool_false()})); /// let args: &[Rc<Term>] = match_term!((and ...) = t).unwrap(); /// ``` @@ -175,7 +175,7 @@ macro_rules! match_term_err { /// Building the term `(and true (not false))`: /// ``` /// # use carcara::{ast::*, build_term, match_term}; -/// let mut pool = TermPool::new(); +/// let mut pool = PrimitivePool::new(); /// let t = build_term!(pool, (and {pool.bool_true()} (not {pool.bool_false()}))); /// assert!(match_term!((and true (not false)) = t).is_some()); /// ``` @@ -249,13 +249,13 @@ macro_rules! 
impl_str_conversion_traits { #[cfg(test)] mod tests { - use crate::ast::*; + use crate::ast::{pool::PrimitivePool, *}; use crate::parser::tests::{parse_term, parse_terms}; #[test] fn test_match_term() { - let mut p = TermPool::new(); - let [one, two, five] = [1, 2, 5].map(|n| p.add(Term::integer(n))); + let mut p = PrimitivePool::new(); + let [one, two, five] = [1, 2, 5].map(|n| p.add(Term::new_int(n))); let term = parse_term(&mut p, "(= (= (not false) (= true false)) (not true))"); let ((a, (b, c)), d) = match_term!((= (= (not a) (= b c)) (not d)) = &term).unwrap(); @@ -303,13 +303,13 @@ mod tests { (declare-fun p () Bool) (declare-fun q () Bool) "; - let mut pool = TermPool::new(); + let mut pool = PrimitivePool::new(); let bool_sort = pool.add(Term::Sort(Sort::Bool)); let int_sort = pool.add(Term::Sort(Sort::Int)); - let [one, two, three] = [1, 2, 3].map(|n| pool.add(Term::integer(n))); - let [a, b] = ["a", "b"].map(|s| pool.add(Term::var(s, int_sort.clone()))); - let [p, q] = ["p", "q"].map(|s| pool.add(Term::var(s, bool_sort.clone()))); + let [one, two, three] = [1, 2, 3].map(|n| pool.add(Term::new_int(n))); + let [a, b] = ["a", "b"].map(|s| pool.add(Term::new_var(s, int_sort.clone()))); + let [p, q] = ["p", "q"].map(|s| pool.add(Term::new_var(s, bool_sort.clone()))); let cases = [ ("(= a b)", build_term!(pool, (= {a} {b}))), diff --git a/carcara/src/ast/mod.rs b/carcara/src/ast/mod.rs index 424e59d6..51a513f7 100644 --- a/carcara/src/ast/mod.rs +++ b/carcara/src/ast/mod.rs @@ -4,26 +4,28 @@ #[macro_use] mod macros; -mod deep_eq; +mod context; mod iter; -mod pool; +mod polyeq; +pub mod pool; pub(crate) mod printer; mod rc; mod substitution; #[cfg(test)] mod tests; -pub use deep_eq::{are_alpha_equivalent, deep_eq, tracing_deep_eq}; +pub use context::{Context, ContextStack}; pub use iter::ProofIter; -pub use pool::TermPool; +pub use polyeq::{alpha_equiv, polyeq, tracing_polyeq}; +pub use pool::{PrimitivePool, TermPool}; pub use printer::print_proof; pub use rc::Rc; pub use substitution::{Substitution, SubstitutionError}; -pub(crate) use deep_eq::{DeepEq, DeepEqualityChecker}; +pub(crate) use polyeq::{Polyeq, PolyeqComparator}; use crate::checker::error::CheckerError; -use ahash::AHashSet; +use indexmap::IndexSet; use rug::Integer; use rug::Rational; use std::{hash::Hash, ops::Deref}; @@ -49,7 +51,7 @@ pub struct Proof { /// The proof's premises. /// /// Those are the terms introduced in the original problem's `assert` commands. - pub premises: AHashSet<Rc<Term>>, + pub premises: IndexSet<Rc<Term>>, /// The proof commands. pub commands: Vec<ProofCommand>, @@ -158,6 +160,9 @@ pub struct Subproof { /// The "variable" style arguments of the subproof, of the form `(<symbol> <sort>)`. pub variable_args: Vec<SortedVar>, + + /// Subproof id, used for context hashing purposes + pub context_id: usize, } /// An argument for a `step` command. @@ -370,7 +375,7 @@ impl AsRef<[SortedVar]> for BindingList { } impl Deref for BindingList { - type Target = [SortedVar]; + type Target = Vec<SortedVar>; fn deref(&self) -> &Self::Target { &self.0 } @@ -389,11 +394,6 @@ impl<'a> IntoIterator for &'a BindingList { impl BindingList { pub const EMPTY: &'static Self = &BindingList(Vec::new()); - - /// Extract a slice of the binding list's contents. - pub fn as_slice(&self) -> &[SortedVar] { - self.0.as_slice() - } } /// A term. /// /// Many additional methods are implemented in [`Rc`]. #[derive(Clone, PartialEq, Eq, Hash)] pub enum Term { - /// A terminal. This can be a constant or a variable. - Terminal(Terminal), + /// A constant term.
+ Const(Constant), + + /// A variable, consisting of an identifier and a sort. + Var(Ident, Rc<Term>), /// An application of a function to one or more terms. App(Rc<Term>, Vec<Rc<Term>>), @@ -428,50 +431,47 @@ impl From<SortedVar> for Term { fn from(var: SortedVar) -> Self { - Term::Terminal(Terminal::Var(Identifier::Simple(var.0), var.1)) + Term::Var(Ident::Simple(var.0), var.1) } } impl Term { /// Constructs a new integer term. - pub fn integer(value: impl Into<Integer>) -> Self { - Term::Terminal(Terminal::Integer(value.into())) + pub fn new_int(value: impl Into<Integer>) -> Self { + Term::Const(Constant::Integer(value.into())) } /// Constructs a new real term. - pub fn real(value: impl Into<Rational>) -> Self { - Term::Terminal(Terminal::Real(value.into())) + pub fn new_real(value: impl Into<Rational>) -> Self { + Term::Const(Constant::Real(value.into())) } /// Constructs a new string term. - pub fn string(value: impl Into<String>) -> Self { - Term::Terminal(Terminal::String(value.into())) + pub fn new_string(value: impl Into<String>) -> Self { + Term::Const(Constant::String(value.into())) } /// Constructs a new variable term. - pub fn var(name: impl Into<String>, sort: Rc<Term>) -> Self { - Term::Terminal(Terminal::Var(Identifier::Simple(name.into()), sort)) + pub fn new_var(name: impl Into<String>, sort: Rc<Term>) -> Self { + Term::Var(Ident::Simple(name.into()), sort) } /// Returns the sort of this term. This does not make use of a cache --- if possible, prefer to /// use `TermPool::sort`. pub fn raw_sort(&self) -> Sort { - let mut pool = TermPool::new(); + let mut pool = PrimitivePool::new(); let added = pool.add(self.clone()); - pool.sort(&added).clone() + pool.sort(&added).as_sort().unwrap().clone() } - /// Returns `true` if the term is a terminal. + /// Returns `true` if the term is a terminal, that is, if it is a constant or a variable. pub fn is_terminal(&self) -> bool { - matches!(self, Term::Terminal(_)) + matches!(self, Term::Const(_) | Term::Var(..)) } /// Returns `true` if the term is an integer or real constant. pub fn is_number(&self) -> bool { - matches!( - self, - Term::Terminal(Terminal::Real(_) | Terminal::Integer(_)) - ) + matches!(self, Term::Const(Constant::Real(_) | Constant::Integer(_))) } /// Returns `true` if the term is an integer or real constant, or one such constant negated @@ -487,8 +487,8 @@ impl Term { /// constant. pub fn as_number(&self) -> Option<Rational> { match self { - Term::Terminal(Terminal::Real(r)) => Some(r.clone()), - Term::Terminal(Terminal::Integer(i)) => Some(i.clone().into()), + Term::Const(Constant::Real(r)) => Some(r.clone()), + Term::Const(Constant::Integer(i)) => Some(i.clone().into()), _ => None, } } @@ -527,17 +527,14 @@ impl Term { /// Returns `true` if the term is a variable. pub fn is_var(&self) -> bool { - matches!( - self, - Term::Terminal(Terminal::Var(Identifier::Simple(_), _)) - ) + matches!(self, Term::Var(Ident::Simple(_), _)) } /// Tries to extract the variable name from a term. Returns `Some` if the term is a variable /// with a simple identifier. pub fn as_var(&self) -> Option<&str> { match self { - Term::Terminal(Terminal::Var(Identifier::Simple(var), _)) => Some(var.as_str()), + Term::Var(Ident::Simple(var), _) => Some(var.as_str()), _ => None, } } @@ -557,7 +554,7 @@ impl Term { /// Tries to unwrap an operation term, returning the `Operator` and the arguments. Returns /// `None` if the term is not an operation term.
- pub fn unwrap_op(&self) -> Option<(Operator, &[Rc<Term>])> { + pub fn as_op(&self) -> Option<(Operator, &[Rc<Term>])> { match self { Term::Op(op, args) => Some((*op, args.as_slice())), _ => None, } @@ -566,7 +563,7 @@ impl Term { /// Tries to unwrap a quantifier term, returning the `Quantifier`, the bindings and the inner /// term. Returns `None` if the term is not a quantifier term. - pub fn unwrap_quant(&self) -> Option<(Quantifier, &BindingList, &Rc<Term>)> { + pub fn as_quant(&self) -> Option<(Quantifier, &BindingList, &Rc<Term>)> { match self { Term::Quant(q, b, t) => Some((*q, b, t)), _ => None, } @@ -575,7 +572,7 @@ impl Term { /// Tries to unwrap a `let` term, returning the bindings and the inner term. Returns `None` if /// the term is not a `let` term. - pub fn unwrap_let(&self) -> Option<(&BindingList, &Rc<Term>)> { + pub fn as_let(&self) -> Option<(&BindingList, &Rc<Term>)> { match self { Term::Let(b, t) => Some((b, t)), _ => None, } @@ -584,7 +581,7 @@ impl Term { /// Returns `true` if the term is the boolean constant `true`. pub fn is_bool_true(&self) -> bool { - if let Term::Terminal(Terminal::Var(Identifier::Simple(name), sort)) = self { + if let Term::Var(Ident::Simple(name), sort) = self { sort.as_sort() == Some(&Sort::Bool) && name == "true" } else { false @@ -593,7 +590,7 @@ impl Term { /// Returns `true` if the term is the boolean constant `false`. pub fn is_bool_false(&self) -> bool { - if let Term::Terminal(Terminal::Var(Identifier::Simple(name), sort)) = self { + if let Term::Var(Ident::Simple(name), sort) = self { sort.as_sort() == Some(&Sort::Bool) && name == "false" } else { false @@ -660,29 +657,29 @@ impl Rc<Term> { /// Tries to unwrap an operation term, returning the `Operator` and the arguments. Returns a /// `CheckerError` if the term is not an operation term. - pub fn unwrap_op_err(&self) -> Result<(Operator, &[Rc<Term>]), CheckerError> { - self.unwrap_op() + pub fn as_op_err(&self) -> Result<(Operator, &[Rc<Term>]), CheckerError> { + self.as_op() .ok_or_else(|| CheckerError::ExpectedOperationTerm(self.clone())) } /// Tries to unwrap a quantifier term, returning the `Quantifier`, the bindings and the inner /// term. Returns a `CheckerError` if the term is not a quantifier term. - pub fn unwrap_quant_err(&self) -> Result<(Quantifier, &BindingList, &Rc<Term>), CheckerError> { - self.unwrap_quant() + pub fn as_quant_err(&self) -> Result<(Quantifier, &BindingList, &Rc<Term>), CheckerError> { + self.as_quant() .ok_or_else(|| CheckerError::ExpectedQuantifierTerm(self.clone())) } /// Tries to unwrap a `let` term, returning the bindings and the inner /// term. Returns a `CheckerError` if the term is not a `let` term. - pub fn unwrap_let_err(&self) -> Result<(&BindingList, &Rc<Term>), CheckerError> { - self.unwrap_let() + pub fn as_let_err(&self) -> Result<(&BindingList, &Rc<Term>), CheckerError> { + self.as_let() .ok_or_else(|| CheckerError::ExpectedLetTerm(self.clone())) } } -/// A terminal term. +/// A constant term. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Terminal { +pub enum Constant { /// An integer constant term. Integer(Integer), @@ -691,24 +688,21 @@ pub enum Terminal { /// A string literal term. String(String), - - /// A variable, consisting of an identifier and a sort. - Var(Identifier, Rc<Term>), } /// An identifier. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum Identifier { +pub enum Ident { /// A simple identifier, consisting of a symbol. Simple(String), /// An indexed identifier, consisting of a symbol and one or more indices. - Indexed(String, Vec<IdentifierIndex>), + Indexed(String, Vec<IdentIndex>), } /// An index for an indexed identifier.
This can be either a numeral or a symbol. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum IdentifierIndex { +pub enum IdentIndex { Numeral(u64), Symbol(String), } diff --git a/carcara/src/ast/deep_eq.rs b/carcara/src/ast/polyeq.rs similarity index 56% rename from carcara/src/ast/deep_eq.rs rename to carcara/src/ast/polyeq.rs index c518cc94..7ba7f592 100644 --- a/carcara/src/ast/deep_eq.rs +++ b/carcara/src/ast/polyeq.rs @@ -1,23 +1,22 @@ //! This module implements less strict definitions of equality for terms. In particular, it //! contains two definitions of equality that differ from `PartialEq`: //! -//! - `deep_eq` considers `=` terms that are reflections of each other as equal, meaning the terms +//! - `polyeq` considers `=` terms that are reflections of each other as equal, meaning the terms //! `(= a b)` and `(= b a)` are considered equal by this method. //! -//! - `are_alpha_equivalent` compares terms by alpha-equivalence, meaning it implements equality of -//! terms modulo renaming of bound variables. +//! - `alpha_equiv` compares terms by alpha-equivalence, meaning it implements equality of terms +//! modulo renaming of bound variables. use super::{ - BindingList, Identifier, Operator, ProofArg, ProofCommand, ProofStep, Rc, Sort, Subproof, Term, - Terminal, + BindingList, Ident, Operator, ProofArg, ProofCommand, ProofStep, Rc, Sort, Subproof, Term, }; -use crate::utils::SymbolTable; +use crate::utils::HashMapStack; use std::time::{Duration, Instant}; /// A trait that represents objects that can be compared for equality modulo reordering of /// equalities or alpha equivalence. -pub trait DeepEq { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool; +pub trait Polyeq { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool; } /// Computes whether the two given terms are equal, modulo reordering of equalities. /// @@ -26,28 +25,28 @@ pub trait DeepEq { /// equal, meaning terms like `(and p (= a b))` and `(and p (= b a))` are considered equal. /// /// This function records how long it takes to run, and adds that duration to the `time` argument. -pub fn deep_eq(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> bool { +pub fn polyeq(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> bool { let start = Instant::now(); - let result = DeepEq::eq(&mut DeepEqualityChecker::new(true, false), a, b); + let result = Polyeq::eq(&mut PolyeqComparator::new(true, false), a, b); *time += start.elapsed(); result } -/// Similar to `deep_eq`, but also records the maximum depth the deep equality checker reached when +/// Similar to `polyeq`, but also records the maximum depth the polyequality comparator reached when /// comparing the terms. /// /// This function records how long it takes to run, and adds that duration to the `time` argument. -pub fn tracing_deep_eq(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> (bool, usize) { +pub fn tracing_polyeq(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> (bool, usize) { let start = Instant::now(); - let mut checker = DeepEqualityChecker::new(true, false); - let result = DeepEq::eq(&mut checker, a, b); + let mut comp = PolyeqComparator::new(true, false); + let result = Polyeq::eq(&mut comp, a, b); *time += start.elapsed(); - (result, checker.max_depth) + (result, comp.max_depth) } -/// Similar to `deep_eq`, but instead compares terms for alpha equivalence. +/// Similar to `polyeq`, but instead compares terms for alpha equivalence. /// /// This means that two terms which are the same, except for the renaming of a bound variable, are /// considered equivalent.
This function still considers equality modulo reordering of equalities. @@ -55,21 +54,21 @@ pub fn tracing_deep_eq(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> (bool /// Int)) (= 0 y))` as equivalent. /// /// This function records how long it takes to run, and adds that duration to the `time` argument. -pub fn are_alpha_equivalent(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> bool { +pub fn alpha_equiv(a: &Rc<Term>, b: &Rc<Term>, time: &mut Duration) -> bool { let start = Instant::now(); // When we are checking for alpha-equivalence, we can't always assume that if `a` and `b` are - // identical, they are alpha-equivalent, so that optimization is not used in `DeepEq::eq`. + // identical, they are alpha-equivalent, so that optimization is not used in `Polyeq::eq`. // However, here at the "root" level this assumption is valid, so we check if the terms are // directly equal before doing anything else - let result = a == b || DeepEq::eq(&mut DeepEqualityChecker::new(true, true), a, b); + let result = a == b || Polyeq::eq(&mut PolyeqComparator::new(true, true), a, b); *time += start.elapsed(); result } -/// A configurable checker for equality modulo reordering of equalities and alpha equivalence. -pub struct DeepEqualityChecker { +/// A configurable comparator for polyequality and alpha equivalence. +pub struct PolyeqComparator { // In order to check alpha-equivalence, we can't use a simple global cache. For instance, let's // say we are comparing the following terms for alpha equivalence: // ``` @@ -88,31 +87,31 @@ pub struct DeepEqualityChecker { // are comparing the second argument of each term, `(< x y)` will again be `(< $0 $1)` in `a`, // but it will be `(< $1 $0)` in `b`. If we just rely on the cache, we will incorrectly // determine that `a` and `b` are alpha-equivalent. To account for that, we use a more - // complicated caching system, based on a `SymbolTable`. We push a new scope every time we enter - // a binder term, and pop it as we exit. This unfortunately means that equalities derived + // complicated caching system, based on a `HashMapStack`. We push a new scope every time we + // enter a binder term, and pop it as we exit. This unfortunately means that equalities derived // inside a binder term can't be reused outside of it, degrading performance. If we are not - // checking for alpha-equivalence, we never push an additional scope to this `SymbolTable`, - // meaning it functions as a simple hash set. - cache: SymbolTable<(Rc<Term>, Rc<Term>), ()>, + // checking for alpha-equivalence, we never push an additional scope to this `HashMapStack`, + // meaning it functions as a simple hash map. + cache: HashMapStack<(Rc<Term>, Rc<Term>), ()>, is_mod_reordering: bool, - alpha_equiv_checker: Option<AlphaEquivalenceChecker>, + de_bruijn_map: Option<DeBruijnMap>, current_depth: usize, max_depth: usize, } -impl DeepEqualityChecker { - /// Constructs a new `DeepEqualityChecker`. +impl PolyeqComparator { + /// Constructs a new `PolyeqComparator`. /// - /// If `is_mod_reordering` is `true`, the checker will compare terms modulo reordering of - /// equalities. If `is_alpha_equivalence` is `true`, the checker will compare terms for alpha + /// If `is_mod_reordering` is `true`, the comparator will compare terms modulo reordering of + /// equalities. If `is_alpha_equivalence` is `true`, the comparator will compare terms for alpha /// equivalence.
pub fn new(is_mod_reordering: bool, is_alpha_equivalence: bool) -> Self { Self { is_mod_reordering, - cache: SymbolTable::new(), - alpha_equiv_checker: if is_alpha_equivalence { - Some(AlphaEquivalenceChecker::new()) + cache: HashMapStack::new(), + de_bruijn_map: if is_alpha_equivalence { + Some(DeBruijnMap::new()) } else { None }, @@ -121,234 +120,226 @@ impl DeepEqualityChecker { } } - fn check_binder( + fn compare_binder( &mut self, a_binds: &BindingList, b_binds: &BindingList, a_inner: &Rc<Term>, b_inner: &Rc<Term>, ) -> bool { - if let Some(alpha_checker) = self.alpha_equiv_checker.as_mut() { - // First, we push new scopes into the alpha-equivalence checker and the cache stack - alpha_checker.push(); + if let Some(de_bruijn_map) = self.de_bruijn_map.as_mut() { + // First, we push new scopes into the De Bruijn map and the cache stack + de_bruijn_map.push(); self.cache.push_scope(); // Then, we check that the binding lists and the inner terms are equivalent for (a_var, b_var) in a_binds.iter().zip(b_binds.iter()) { - if !DeepEq::eq(self, &a_var.1, &b_var.1) { - // We must remember to pop the frames from the alpha equivalence checker and - // cache stack here, so as not to leave them in a corrupted state - self.alpha_equiv_checker.as_mut().unwrap().pop(); + if !Polyeq::eq(self, &a_var.1, &b_var.1) { + // We must remember to pop the frames from the De Bruijn map and cache stack + // here, so as not to leave them in a corrupted state + self.de_bruijn_map.as_mut().unwrap().pop(); self.cache.pop_scope(); return false; } - // We also insert each variable in the binding lists into the alpha-equivalence - // checker - self.alpha_equiv_checker + // We also insert each variable in the binding lists into the De Bruijn map + self.de_bruijn_map .as_mut() .unwrap() .insert(a_var.0.clone(), b_var.0.clone()); } - let result = DeepEq::eq(self, a_inner, b_inner); + let result = Polyeq::eq(self, a_inner, b_inner); // Finally, we pop the scopes we pushed - self.alpha_equiv_checker.as_mut().unwrap().pop(); + self.de_bruijn_map.as_mut().unwrap().pop(); self.cache.pop_scope(); result } else { - DeepEq::eq(self, a_binds, b_binds) && DeepEq::eq(self, a_inner, b_inner) + Polyeq::eq(self, a_binds, b_binds) && Polyeq::eq(self, a_inner, b_inner) } } } -impl DeepEq for Rc<Term> { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for Rc<Term> { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { // If the two `Rc`s are directly equal, and we are not checking for alpha-equivalence, we // can return `true`. // Note that if we are checking for alpha-equivalence, identical terms may be considered // different, if the bound variables in them have different meanings. For example, in the // terms `(forall ((x Int) (y Int)) (< x y))` and `(forall ((y Int) (x Int)) (< x y))`, // even though both instances of `(< x y)` are identical, they are not alpha-equivalent.
- if checker.alpha_equiv_checker.is_none() && a == b { + if comp.de_bruijn_map.is_none() && a == b { return true; } // We first check the cache to see if these terms were already determined to be equal - if checker.cache.get(&(a.clone(), b.clone())).is_some() { + if comp.cache.get(&(a.clone(), b.clone())).is_some() { return true; } - checker.current_depth += 1; - checker.max_depth = std::cmp::max(checker.max_depth, checker.current_depth); - let result = DeepEq::eq(checker, a.as_ref(), b.as_ref()); + comp.current_depth += 1; + comp.max_depth = std::cmp::max(comp.max_depth, comp.current_depth); + let result = Polyeq::eq(comp, a.as_ref(), b.as_ref()); if result { - checker.cache.insert((a.clone(), b.clone()), ()); + comp.cache.insert((a.clone(), b.clone()), ()); } - checker.current_depth -= 1; + comp.current_depth -= 1; result } } -impl DeepEq for Term { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for Term { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { match (a, b) { + (Term::Const(a), Term::Const(b)) => a == b, + (Term::Var(Ident::Simple(a), a_sort), Term::Var(Ident::Simple(b), b_sort)) + if comp.de_bruijn_map.is_some() => + { + // If we are checking for alpha-equivalence, and we encounter two variables, we + // check that they are equivalent using the De Bruijn map + let db = comp.de_bruijn_map.as_mut().unwrap(); + db.compare(a, b) && Polyeq::eq(comp, a_sort, b_sort) + } + (Term::Var(a, a_sort), Term::Var(b, b_sort)) => { + a == b && Polyeq::eq(comp, a_sort, b_sort) + } (Term::App(f_a, args_a), Term::App(f_b, args_b)) => { - DeepEq::eq(checker, f_a, f_b) && DeepEq::eq(checker, args_a, args_b) + Polyeq::eq(comp, f_a, f_b) && Polyeq::eq(comp, args_a, args_b) } (Term::Op(op_a, args_a), Term::Op(op_b, args_b)) => { - if checker.is_mod_reordering { + if comp.is_mod_reordering { if let (Operator::Equals, [a_1, a_2], Operator::Equals, [b_1, b_2]) = (op_a, args_a.as_slice(), op_b, args_b.as_slice()) { // If the term is an equality of two terms, we also check if they would be // equal if one of them was flipped - return DeepEq::eq(checker, &(a_1, a_2), &(b_1, b_2)) - || DeepEq::eq(checker, &(a_1, a_2), &(b_2, b_1)); + return Polyeq::eq(comp, &(a_1, a_2), &(b_1, b_2)) + || Polyeq::eq(comp, &(a_1, a_2), &(b_2, b_1)); } } // General case - op_a == op_b && DeepEq::eq(checker, args_a, args_b) + op_a == op_b && Polyeq::eq(comp, args_a, args_b) } - (Term::Sort(a), Term::Sort(b)) => DeepEq::eq(checker, a, b), - (Term::Terminal(a), Term::Terminal(b)) => match (a, b) { - // If we are checking for alpha-equivalence, and we encounter two variables, we - // check that they are equivalent using the alpha-equivalence checker - ( - Terminal::Var(Identifier::Simple(a_var), a_sort), - Terminal::Var(Identifier::Simple(b_var), b_sort), - ) if checker.alpha_equiv_checker.is_some() => { - let alpha = checker.alpha_equiv_checker.as_mut().unwrap(); - alpha.check(a_var, b_var) && DeepEq::eq(checker, a_sort, b_sort) - } - - (Terminal::Var(iden_a, sort_a), Terminal::Var(iden_b, sort_b)) => { - iden_a == iden_b && DeepEq::eq(checker, sort_a, sort_b) - } - (a, b) => a == b, - }, + (Term::Sort(a), Term::Sort(b)) => Polyeq::eq(comp, a, b), (Term::Quant(q_a, _, _), Term::Quant(q_b, _, _)) if q_a != q_b => false, (Term::Quant(_, a_binds, a), Term::Quant(_, b_binds, b)) | (Term::Let(a_binds, a), Term::Let(b_binds, b)) | (Term::Lambda(a_binds, a), Term::Lambda(b_binds, b)) => { - checker.check_binder(a_binds, b_binds, a, b) + comp.compare_binder(a_binds, b_binds, a, b) } 
(Term::Choice(a_var, a), Term::Choice(b_var, b)) => { let a_binds = BindingList(vec![a_var.clone()]); let b_binds = BindingList(vec![b_var.clone()]); - checker.check_binder(&a_binds, &b_binds, a, b) + comp.compare_binder(&a_binds, &b_binds, a, b) } _ => false, } } } -impl DeepEq for BindingList { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - DeepEq::eq(checker, &a.0, &b.0) +impl Polyeq for BindingList { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + Polyeq::eq(comp, &a.0, &b.0) } } -impl DeepEq for Sort { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for Sort { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { match (a, b) { (Sort::Function(sorts_a), Sort::Function(sorts_b)) => { - DeepEq::eq(checker, sorts_a, sorts_b) + Polyeq::eq(comp, sorts_a, sorts_b) } (Sort::Atom(a, sorts_a), Sort::Atom(b, sorts_b)) => { - a == b && DeepEq::eq(checker, sorts_a, sorts_b) + a == b && Polyeq::eq(comp, sorts_a, sorts_b) } (Sort::Bool, Sort::Bool) | (Sort::Int, Sort::Int) | (Sort::Real, Sort::Real) | (Sort::String, Sort::String) => true, (Sort::Array(x_a, y_a), Sort::Array(x_b, y_b)) => { - DeepEq::eq(checker, x_a, x_b) && DeepEq::eq(checker, y_a, y_b) + Polyeq::eq(comp, x_a, x_b) && Polyeq::eq(comp, y_a, y_b) } _ => false, } } } -impl<T: DeepEq> DeepEq for &T { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - DeepEq::eq(checker, *a, *b) +impl<T: Polyeq> Polyeq for &T { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + Polyeq::eq(comp, *a, *b) } } -impl<T: DeepEq> DeepEq for [T] { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - a.len() == b.len() - && a.iter() - .zip(b.iter()) - .all(|(a, b)| DeepEq::eq(checker, a, b)) +impl<T: Polyeq> Polyeq for [T] { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + a.len() == b.len() && a.iter().zip(b.iter()).all(|(a, b)| Polyeq::eq(comp, a, b)) } } -impl<T: DeepEq> DeepEq for Vec<T> { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - DeepEq::eq(checker, a.as_slice(), b.as_slice()) +impl<T: Polyeq> Polyeq for Vec<T> { - fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + Polyeq::eq(comp, a.as_slice(), b.as_slice()) } } -impl<T: DeepEq, U: DeepEq> DeepEq for (T, U) { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - DeepEq::eq(checker, &a.0, &b.0) && DeepEq::eq(checker, &a.1, &b.1) +impl<T: Polyeq, U: Polyeq> Polyeq for (T, U) { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + Polyeq::eq(comp, &a.0, &b.0) && Polyeq::eq(comp, &a.1, &b.1) } } -impl DeepEq for String { - fn eq(_: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for String { + fn eq(_: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { a == b } } -impl DeepEq for ProofArg { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for ProofArg { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { match (a, b) { - (ProofArg::Term(a), ProofArg::Term(b)) => DeepEq::eq(checker, a, b), + (ProofArg::Term(a), ProofArg::Term(b)) => Polyeq::eq(comp, a, b), (ProofArg::Assign(sa, ta), ProofArg::Assign(sb, tb)) => { - sa == sb && DeepEq::eq(checker, ta, tb) + sa == sb && Polyeq::eq(comp, ta, tb) } _ => false, } } } -impl DeepEq for ProofCommand { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for ProofCommand { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { match (a, b) { ( ProofCommand::Assume { id: a_id, term: a_term },
ProofCommand::Assume { id: b_id, term: b_term }, - ) => a_id == b_id && DeepEq::eq(checker, a_term, b_term), - (ProofCommand::Step(a), ProofCommand::Step(b)) => DeepEq::eq(checker, a, b), - (ProofCommand::Subproof(a), ProofCommand::Subproof(b)) => DeepEq::eq(checker, a, b), + ) => a_id == b_id && Polyeq::eq(comp, a_term, b_term), + (ProofCommand::Step(a), ProofCommand::Step(b)) => Polyeq::eq(comp, a, b), + (ProofCommand::Subproof(a), ProofCommand::Subproof(b)) => Polyeq::eq(comp, a, b), _ => false, } } } -impl DeepEq for ProofStep { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { +impl Polyeq for ProofStep { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { a.id == b.id - && DeepEq::eq(checker, &a.clause, &b.clause) + && Polyeq::eq(comp, &a.clause, &b.clause) && a.rule == b.rule && a.premises == b.premises - && DeepEq::eq(checker, &a.args, &b.args) + && Polyeq::eq(comp, &a.args, &b.args) && a.discharge == b.discharge } } -impl DeepEq for Subproof { - fn eq(checker: &mut DeepEqualityChecker, a: &Self, b: &Self) -> bool { - DeepEq::eq(checker, &a.commands, &b.commands) - && DeepEq::eq(checker, &a.assignment_args, &b.assignment_args) - && DeepEq::eq(checker, &a.variable_args, &b.variable_args) +impl Polyeq for Subproof { + fn eq(comp: &mut PolyeqComparator, a: &Self, b: &Self) -> bool { + Polyeq::eq(comp, &a.commands, &b.commands) + && Polyeq::eq(comp, &a.assignment_args, &b.assignment_args) + && Polyeq::eq(comp, &a.variable_args, &b.variable_args) } } -struct AlphaEquivalenceChecker { +struct DeBruijnMap { // To check for alpha-equivalence, we make use of De Bruijn indices. The idea is to map each // bound variable to an integer depending on the order in which they were bound. As we compare // the two terms, if we encounter two bound variables, we need only to check if the associated @@ -367,14 +358,16 @@ struct AlphaEquivalenceChecker { // that is bound second are assigned `$1`, etc. The given term would then be represented like // this: // `(forall ((x Int)) (and (exists ((y Int)) (> $0 $1)) (> $0 5)))` - indices: (SymbolTable<String, usize>, SymbolTable<String, usize>), - counter: Vec<usize>, // Holds the count of how many variables were bound before each depth + indices: (HashMapStack<String, usize>, HashMapStack<String, usize>), + + // Holds the count of how many variables were bound before each depth + counter: Vec<usize>, } -impl AlphaEquivalenceChecker { +impl DeBruijnMap { fn new() -> Self { Self { - indices: (SymbolTable::new(), SymbolTable::new()), + indices: (HashMapStack::new(), HashMapStack::new()), counter: vec![0], } } @@ -390,7 +383,7 @@ impl AlphaEquivalenceChecker { self.indices.0.pop_scope(); self.indices.1.pop_scope(); - // If we successfully popped the scopes from the symbol tables, that means that there was + // If we successfully popped the scopes from the indices stacks, that means that there was // at least one scope, so we can safely pop from the counter stack as well self.counter.pop(); } @@ -402,7 +395,7 @@ impl AlphaEquivalenceChecker { *current += 1; } - fn check(&self, a: &str, b: &str) -> bool { + fn compare(&self, a: &str, b: &str) -> bool { match (self.indices.0.get(a), self.indices.1.get(b)) { // If both a and b are free variables, they need to have the same name (None, None) => a == b,
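The De Bruijn bookkeeping behind `DeBruijnMap` can also be illustrated in isolation. This standalone sketch uses plain `HashMap`s instead of `HashMapStack` and omits scope popping, so it only handles a single binder level; all names are illustrative, not Carcara's API:

```rust
use std::collections::HashMap;

// Maps bound variable names on each side of the comparison to De Bruijn indices.
#[derive(Default)]
struct DeBruijn {
    left: HashMap<String, usize>,
    right: HashMap<String, usize>,
    counter: usize,
}

impl DeBruijn {
    // Called when entering a binder, e.g. `forall a.` on one side and `forall b.` on the other.
    fn bind(&mut self, a: &str, b: &str) {
        self.left.insert(a.to_string(), self.counter);
        self.right.insert(b.to_string(), self.counter);
        self.counter += 1;
    }

    // Mirrors `DeBruijnMap::compare`: bound variables compare by index, free ones by name.
    fn compare(&self, a: &str, b: &str) -> bool {
        match (self.left.get(a), self.right.get(b)) {
            (None, None) => a == b,       // both free: names must match
            (Some(i), Some(j)) => i == j, // both bound: indices must match
            _ => false,                   // one bound, one free: never equal
        }
    }
}

fn main() {
    let mut map = DeBruijn::default();
    map.bind("x", "y"); // comparing `(forall ((x Int)) ...)` against `(forall ((y Int)) ...)`
    assert!(map.compare("x", "y")); // both have index 0
    assert!(map.compare("z", "z")); // free variables compare by name
    assert!(!map.compare("x", "z")); // bound vs. free
}
```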
- -use super::{Identifier, Rc, Sort, Term, Terminal}; -use ahash::{AHashMap, AHashSet}; - -/// A structure to store and manage all allocated terms. -/// -/// You can add a `Term` to the pool using [`TermPool::add`], which will return an `Rc`. This -/// struct ensures that, if two equal terms are added to a pool, they will be in the same -/// allocation. This invariant allows terms to be safely compared and hashed by reference, instead -/// of by value (see [`Rc`]). -/// -/// This struct also provides other utility methods, like computing the sort of a term (see -/// [`TermPool::sort`]) or its free variables (see [`TermPool::free_vars`]). -pub struct TermPool { - /// A map of the terms in the pool. - pub(crate) terms: AHashMap>, - free_vars_cache: AHashMap, AHashSet>>, - sorts_cache: AHashMap, Sort>, - bool_true: Rc, - bool_false: Rc, -} - -impl Default for TermPool { - fn default() -> Self { - Self::new() - } -} - -impl TermPool { - /// Constructs a new `TermPool`. This new pool will already contain the boolean constants `true` - /// and `false`, as well as the `Bool` sort. - pub fn new() -> Self { - let mut terms = AHashMap::new(); - let mut sorts_cache = AHashMap::new(); - let bool_sort = Self::add_term_to_map(&mut terms, Term::Sort(Sort::Bool)); - - let [bool_true, bool_false] = ["true", "false"].map(|b| { - Self::add_term_to_map( - &mut terms, - Term::Terminal(Terminal::Var( - Identifier::Simple(b.into()), - bool_sort.clone(), - )), - ) - }); - - sorts_cache.insert(bool_false.clone(), Sort::Bool); - sorts_cache.insert(bool_true.clone(), Sort::Bool); - sorts_cache.insert(bool_sort, Sort::Bool); - - Self { - terms, - free_vars_cache: AHashMap::new(), - sorts_cache, - bool_true, - bool_false, - } - } - - /// Returns the term corresponding to the boolean constant `true`. - pub fn bool_true(&self) -> Rc { - self.bool_true.clone() - } - - /// Returns the term corresponding to the boolean constant `false`. - pub fn bool_false(&self) -> Rc { - self.bool_false.clone() - } - - /// Returns the term corresponding to the boolean constant determined by `value`. - pub fn bool_constant(&self, value: bool) -> Rc { - match value { - true => self.bool_true(), - false => self.bool_false(), - } - } - - fn add_term_to_map(terms_map: &mut AHashMap>, term: Term) -> Rc { - use std::collections::hash_map::Entry; - - match terms_map.entry(term) { - Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), - Entry::Vacant(vacant_entry) => { - let term = vacant_entry.key().clone(); - vacant_entry.insert(Rc::new(term)).clone() - } - } - } - - /// Takes a term and returns a possibly newly allocated `Rc` that references it. - /// - /// If the term was not originally in the term pool, it is added to it. Otherwise, this method - /// just returns an `Rc` pointing to the existing allocation. This method also computes the - /// term's sort, and adds it to the sort cache. - pub fn add(&mut self, term: Term) -> Rc { - let term = Self::add_term_to_map(&mut self.terms, term); - self.compute_sort(&term); - term - } - - /// Takes a vector of terms and calls [`TermPool::add`] on each. - pub fn add_all(&mut self, terms: Vec) -> Vec> { - terms.into_iter().map(|t| self.add(t)).collect() - } - - /// Returns the sort of the given term. - /// - /// This method assumes that the sorts of any subterms have already been checked, and are - /// correct. If `term` is itself a sort, this simply returns that sort. 
- pub fn sort(&self, term: &Rc) -> &Sort { - &self.sorts_cache[term] - } - - /// Computes the sort of a term and adds it to the sort cache. - fn compute_sort<'a, 'b: 'a>(&'a mut self, term: &'b Rc) -> &'a Sort { - use super::Operator; - - if self.sorts_cache.contains_key(term) { - return &self.sorts_cache[term]; - } - - let result = match term.as_ref() { - Term::Terminal(t) => match t { - Terminal::Integer(_) => Sort::Int, - Terminal::Real(_) => Sort::Real, - Terminal::String(_) => Sort::String, - Terminal::Var(_, sort) => sort.as_sort().unwrap().clone(), - }, - Term::Op(op, args) => match op { - Operator::Not - | Operator::Implies - | Operator::And - | Operator::Or - | Operator::Xor - | Operator::Equals - | Operator::Distinct - | Operator::LessThan - | Operator::GreaterThan - | Operator::LessEq - | Operator::GreaterEq - | Operator::IsInt => Sort::Bool, - Operator::Ite => self.compute_sort(&args[1]).clone(), - Operator::Add | Operator::Sub | Operator::Mult => { - if args.iter().any(|a| *self.compute_sort(a) == Sort::Real) { - Sort::Real - } else { - Sort::Int - } - } - Operator::RealDiv | Operator::ToReal => Sort::Real, - Operator::IntDiv | Operator::Mod | Operator::Abs | Operator::ToInt => Sort::Int, - Operator::Select => match self.compute_sort(&args[0]) { - Sort::Array(_, y) => y.as_sort().unwrap().clone(), - _ => unreachable!(), - }, - Operator::Store => self.compute_sort(&args[0]).clone(), - }, - Term::App(f, _) => { - match self.compute_sort(f) { - Sort::Function(sorts) => sorts.last().unwrap().as_sort().unwrap().clone(), - _ => unreachable!(), // We assume that the function is correctly sorted - } - } - Term::Sort(sort) => sort.clone(), - Term::Quant(_, _, _) => Sort::Bool, - Term::Choice((_, sort), _) => sort.as_sort().unwrap().clone(), - Term::Let(_, inner) => self.compute_sort(inner).clone(), - Term::Lambda(bindings, body) => { - let mut result: Vec<_> = - bindings.iter().map(|(_name, sort)| sort.clone()).collect(); - let return_sort = Term::Sort(self.compute_sort(body).clone()); - result.push(self.add(return_sort)); - Sort::Function(result) - } - }; - self.sorts_cache.insert(term.clone(), result); - &self.sorts_cache[term] - } - - /// Returns an `AHashSet` containing all the free variables in the given term. - /// - /// This method uses a cache, so there is no additional cost to computing the free variables of - /// a term multiple times. - pub fn free_vars(&mut self, term: &Rc) -> &AHashSet> { - // Here, I would like to do - // ``` - // if let Some(vars) = self.free_vars_cache.get(term) { - // return vars; - // } - // ``` - // However, because of a limitation in the borrow checker, the compiler thinks that - // this immutable borrow of `cache` has to live until the end of the function, even - // though the code immediately returns. This would stop me from mutating `cache` in the - // rest of the function. Because of that, I have to check if the hash map contains - // `term` as a key, and then get the value associated with it, meaning I have to access - // the hash map twice, which is a bit slower. 
This is an example of problem case #3 - // from the non-lexical lifetimes RFC: - // https://github.com/rust-lang/rfcs/blob/master/text/2094-nll.md - if self.free_vars_cache.contains_key(term) { - return self.free_vars_cache.get(term).unwrap(); - } - let set = match term.as_ref() { - Term::App(f, args) => { - let mut set = self.free_vars(f).clone(); - for a in args { - set.extend(self.free_vars(a).iter().cloned()); - } - set - } - Term::Op(_, args) => { - let mut set = AHashSet::new(); - for a in args { - set.extend(self.free_vars(a).iter().cloned()); - } - set - } - Term::Quant(_, bindings, inner) | Term::Lambda(bindings, inner) => { - let mut vars = self.free_vars(inner).clone(); - for bound_var in bindings { - let term = self.add(bound_var.clone().into()); - vars.remove(&term); - } - vars - } - Term::Let(bindings, inner) => { - let mut vars = self.free_vars(inner).clone(); - for (var, value) in bindings { - let sort = Term::Sort(self.sort(value).clone()); - let sort = self.add(sort); - let term = self.add((var.clone(), sort).into()); - vars.remove(&term); - } - vars - } - Term::Choice(bound_var, inner) => { - let mut vars = self.free_vars(inner).clone(); - let term = self.add(bound_var.clone().into()); - vars.remove(&term); - vars - } - Term::Terminal(Terminal::Var(Identifier::Simple(_), _)) => { - let mut set = AHashSet::with_capacity(1); - set.insert(term.clone()); - set - } - Term::Terminal(_) | Term::Sort(_) => AHashSet::new(), - }; - self.free_vars_cache.insert(term.clone(), set); - self.free_vars_cache.get(term).unwrap() - } -} diff --git a/carcara/src/ast/pool/advanced.rs b/carcara/src/ast/pool/advanced.rs new file mode 100644 index 00000000..abad5c1e --- /dev/null +++ b/carcara/src/ast/pool/advanced.rs @@ -0,0 +1,152 @@ +use super::super::{Rc, Term}; +use super::{PrimitivePool, TermPool}; +use indexmap::IndexSet; +use std::sync::{Arc, RwLock}; + +pub struct ContextPool { + pub(crate) global_pool: Arc, + pub(crate) inner: Arc>, +} + +impl Default for ContextPool { + fn default() -> Self { + Self::new() + } +} + +impl ContextPool { + pub fn new() -> Self { + Self { + global_pool: Arc::new(PrimitivePool::new()), + inner: Arc::new(RwLock::new(PrimitivePool::new())), + } + } + + pub fn from_global(global_pool: &Arc) -> Self { + Self { + global_pool: global_pool.clone(), + inner: Arc::new(RwLock::new(PrimitivePool::new())), + } + } + + pub fn from_previous(ctx_pool: &Self) -> Self { + Self { + global_pool: ctx_pool.global_pool.clone(), + inner: ctx_pool.inner.clone(), + } + } +} + +impl TermPool for ContextPool { + fn bool_true(&self) -> Rc { + self.global_pool.bool_true.clone() + } + + fn bool_false(&self) -> Rc { + self.global_pool.bool_false.clone() + } + + fn add(&mut self, term: Term) -> Rc { + // If the global pool has the term + if let Some(entry) = self.global_pool.storage.get(&term) { + return entry.clone(); + } + let mut ctx_guard = self.inner.write().unwrap(); + let term = ctx_guard.storage.add(term); + ctx_guard.compute_sort(&term); + term + } + + fn sort(&self, term: &Rc) -> Rc { + if let Some(sort) = self.global_pool.sorts_cache.get(term) { + sort.clone() + } + // A sort inserted by context + else { + self.inner.read().unwrap().sorts_cache[term].clone() + } + } + + fn free_vars(&mut self, term: &Rc) -> IndexSet> { + self.inner + .write() + .unwrap() + .free_vars_with_priorities(term, [&self.global_pool]) + } +} + +// ========================================================================= + +pub struct LocalPool { + pub(crate) ctx_pool: ContextPool, + pub(crate) inner: 
PrimitivePool,
+}
+
+impl Default for LocalPool {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl LocalPool {
+    pub fn new() -> Self {
+        Self {
+            ctx_pool: ContextPool::new(),
+            inner: PrimitivePool::new(),
+        }
+    }
+
+    /// Instantiates a new `LocalPool` from a previous `ContextPool` (makes
+    /// sure the context is shared between threads).
+    pub fn from_previous(ctx_pool: &ContextPool) -> Self {
+        Self {
+            ctx_pool: ContextPool::from_previous(ctx_pool),
+            inner: PrimitivePool::new(),
+        }
+    }
+}
+
+impl TermPool for LocalPool {
+    fn bool_true(&self) -> Rc<Term> {
+        self.ctx_pool.global_pool.bool_true.clone()
+    }
+
+    fn bool_false(&self) -> Rc<Term> {
+        self.ctx_pool.global_pool.bool_false.clone()
+    }
+
+    fn add(&mut self, term: Term) -> Rc<Term> {
+        // If the global pool already has the term
+        if let Some(entry) = self.ctx_pool.global_pool.storage.get(&term) {
+            entry.clone()
+        }
+        // If this term was inserted by the context
+        else if let Some(entry) = self.ctx_pool.inner.read().unwrap().storage.get(&term) {
+            entry.clone()
+        } else {
+            self.inner.add(term)
+        }
+    }
+
+    fn sort(&self, term: &Rc<Term>) -> Rc<Term> {
+        if let Some(sort) = self.ctx_pool.global_pool.sorts_cache.get(term) {
+            sort.clone()
+        }
+        // A sort inserted by the context
+        else if let Some(sort) = self.ctx_pool.inner.read().unwrap().sorts_cache.get(term) {
+            sort.clone()
+        } else {
+            self.inner.sorts_cache[term].clone()
+        }
+    }
+
+    fn free_vars(&mut self, term: &Rc<Term>) -> IndexSet<Rc<Term>> {
+        self.inner.free_vars_with_priorities(
+            term,
+            [
+                &self.ctx_pool.global_pool,
+                &self.ctx_pool.inner.read().unwrap(),
+            ],
+        )
+    }
+}
diff --git a/carcara/src/ast/pool/mod.rs b/carcara/src/ast/pool/mod.rs
new file mode 100644
index 00000000..a70ede42
--- /dev/null
+++ b/carcara/src/ast/pool/mod.rs
@@ -0,0 +1,276 @@
+//! This module implements `TermPool`, a structure that stores terms and implements hash consing.
+
+pub mod advanced;
+mod storage;
+
+use super::{Rc, Sort, Term};
+use crate::ast::Constant;
+use indexmap::{IndexMap, IndexSet};
+use storage::Storage;
+
+pub trait TermPool {
+    /// Returns the term corresponding to the boolean constant `true`.
+    fn bool_true(&self) -> Rc<Term>;
+    /// Returns the term corresponding to the boolean constant `false`.
+    fn bool_false(&self) -> Rc<Term>;
+    /// Returns the term corresponding to the boolean constant determined by `value`.
+    fn bool_constant(&self, value: bool) -> Rc<Term> {
+        match value {
+            true => self.bool_true(),
+            false => self.bool_false(),
+        }
+    }
+    /// Takes a term and returns a possibly newly allocated `Rc` that references it.
+    ///
+    /// If the term was not originally in the term pool, it is added to it. Otherwise, this method
+    /// just returns an `Rc` pointing to the existing allocation. This method also computes the
+    /// term's sort, and adds it to the sort cache.
+    fn add(&mut self, term: Term) -> Rc<Term>;
+    /// Takes a vector of terms and calls [`TermPool::add`] on each.
+    fn add_all(&mut self, terms: Vec<Term>) -> Vec<Rc<Term>> {
+        terms.into_iter().map(|t| self.add(t)).collect()
+    }
+    /// Returns the sort of the given term.
+    ///
+    /// This method assumes that the sorts of any subterms have already been checked, and are
+    /// correct. If `term` is itself a sort, this simply returns that sort.
+    fn sort(&self, term: &Rc<Term>) -> Rc<Term>;
+    /// Returns an `IndexSet` containing all the free variables in the given term.
+    ///
+    /// This method uses a cache, so there is no additional cost to computing the free variables of
+    /// a term multiple times.
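+    ///
+    /// A rough usage sketch (illustrative only, not part of the trait contract; `some_term` is a
+    /// hypothetical `Term` value, here assumed to be a variable term):
+    /// ```ignore
+    /// let mut pool = PrimitivePool::new();
+    /// let t = pool.add(some_term); // hash consing: re-adding an equal term returns the same `Rc`
+    /// let vars = pool.free_vars(&t);
+    /// assert!(vars.contains(&t)); // a variable term is its own only free variable
+    /// ```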
+ fn free_vars(&mut self, term: &Rc) -> IndexSet>; +} + +/// A structure to store and manage all allocated terms. +/// +/// You can add a `Term` to the pool using [`PrimitivePool::add`], which will return an `Rc`. This +/// struct ensures that, if two equal terms are added to a pool, they will be in the same +/// allocation. This invariant allows terms to be safely compared and hashed by reference, instead +/// of by value (see [`Rc`]). +/// +/// This struct also provides other utility methods, like computing the sort of a term (see +/// [`PrimitivePool::sort`]) or its free variables (see [`PrimitivePool::free_vars`]). +pub struct PrimitivePool { + pub(crate) storage: Storage, + pub(crate) free_vars_cache: IndexMap, IndexSet>>, + pub(crate) sorts_cache: IndexMap, Rc>, + pub(crate) bool_true: Rc, + pub(crate) bool_false: Rc, +} + +impl Default for PrimitivePool { + fn default() -> Self { + Self::new() + } +} + +impl PrimitivePool { + /// Constructs a new `TermPool`. This new pool will already contain the boolean constants `true` + /// and `false`, as well as the `Bool` sort. + pub fn new() -> Self { + let mut storage = Storage::new(); + let mut sorts_cache = IndexMap::new(); + let bool_sort = storage.add(Term::Sort(Sort::Bool)); + + let [bool_true, bool_false] = + ["true", "false"].map(|b| storage.add(Term::new_var(b, bool_sort.clone()))); + + sorts_cache.insert(bool_false.clone(), bool_sort.clone()); + sorts_cache.insert(bool_true.clone(), bool_sort.clone()); + sorts_cache.insert(bool_sort.clone(), bool_sort); + + Self { + storage, + free_vars_cache: IndexMap::new(), + sorts_cache, + bool_true, + bool_false, + } + } + + /// Computes the sort of a term and adds it to the sort cache. + fn compute_sort(&mut self, term: &Rc) -> Rc { + use super::Operator; + + if let Some(sort) = self.sorts_cache.get(term) { + return sort.clone(); + } + + let result: Sort = match term.as_ref() { + Term::Const(c) => match c { + Constant::Integer(_) => Sort::Int, + Constant::Real(_) => Sort::Real, + Constant::String(_) => Sort::String, + }, + Term::Var(_, sort) => sort.as_sort().unwrap().clone(), + Term::Op(op, args) => match op { + Operator::Not + | Operator::Implies + | Operator::And + | Operator::Or + | Operator::Xor + | Operator::Equals + | Operator::Distinct + | Operator::LessThan + | Operator::GreaterThan + | Operator::LessEq + | Operator::GreaterEq + | Operator::IsInt => Sort::Bool, + Operator::Ite => self.compute_sort(&args[1]).as_sort().unwrap().clone(), + Operator::Add | Operator::Sub | Operator::Mult => { + if args + .iter() + .any(|a| self.compute_sort(a).as_sort().unwrap() == &Sort::Real) + { + Sort::Real + } else { + Sort::Int + } + } + Operator::RealDiv | Operator::ToReal => Sort::Real, + Operator::IntDiv | Operator::Mod | Operator::Abs | Operator::ToInt => Sort::Int, + Operator::Select => match self.compute_sort(&args[0]).as_sort().unwrap() { + Sort::Array(_, y) => y.as_sort().unwrap().clone(), + _ => unreachable!(), + }, + Operator::Store => self.compute_sort(&args[0]).as_sort().unwrap().clone(), + }, + Term::App(f, _) => { + match self.compute_sort(f).as_sort().unwrap() { + Sort::Function(sorts) => sorts.last().unwrap().as_sort().unwrap().clone(), + _ => unreachable!(), // We assume that the function is correctly sorted + } + } + Term::Sort(sort) => sort.clone(), + Term::Quant(_, _, _) => Sort::Bool, + Term::Choice((_, sort), _) => sort.as_sort().unwrap().clone(), + Term::Let(_, inner) => self.compute_sort(inner).as_sort().unwrap().clone(), + Term::Lambda(bindings, body) => { + let mut result: 
Vec<_> =
+                    bindings.iter().map(|(_name, sort)| sort.clone()).collect();
+                result.push(self.compute_sort(body));
+                Sort::Function(result)
+            }
+        };
+        let sort = self.storage.add(Term::Sort(result));
+        self.sorts_cache.insert(term.clone(), sort);
+        self.sorts_cache[term].clone()
+    }
+
+    fn add_with_priorities<const N: usize>(
+        &mut self,
+        term: Term,
+        prior_pools: [&PrimitivePool; N],
+    ) -> Rc<Term> {
+        for p in prior_pools {
+            // If this prior pool has the term
+            if let Some(entry) = p.storage.get(&term) {
+                return entry.clone();
+            }
+        }
+        self.add(term)
+    }
+
+    fn sort_with_priorities<const N: usize>(
+        &mut self,
+        term: &Rc<Term>,
+        prior_pools: [&PrimitivePool; N],
+    ) -> Rc<Term> {
+        for p in prior_pools {
+            if let Some(sort) = p.sorts_cache.get(term) {
+                return sort.clone();
+            }
+        }
+        self.sorts_cache[term].clone()
+    }
+
+    // TODO: Try to work around the lifetime specifiers and return a ref
+    pub fn free_vars_with_priorities<const N: usize>(
+        &mut self,
+        term: &Rc<Term>,
+        prior_pools: [&PrimitivePool; N],
+    ) -> IndexSet<Rc<Term>> {
+        for p in prior_pools {
+            if let Some(set) = p.free_vars_cache.get(term) {
+                return set.clone();
+            }
+        }
+
+        if let Some(set) = self.free_vars_cache.get(term) {
+            return set.clone();
+        }
+
+        let set = match term.as_ref() {
+            Term::App(f, args) => {
+                let mut set = self.free_vars_with_priorities(f, prior_pools);
+                for a in args {
+                    set.extend(self.free_vars_with_priorities(a, prior_pools).into_iter());
+                }
+                set
+            }
+            Term::Op(_, args) => {
+                let mut set = IndexSet::new();
+                for a in args {
+                    set.extend(self.free_vars_with_priorities(a, prior_pools).into_iter());
+                }
+                set
+            }
+            Term::Quant(_, bindings, inner) | Term::Lambda(bindings, inner) => {
+                let mut vars = self.free_vars_with_priorities(inner, prior_pools);
+                for bound_var in bindings {
+                    let term = self.add_with_priorities(bound_var.clone().into(), prior_pools);
+                    vars.remove(&term);
+                }
+                vars
+            }
+            Term::Let(bindings, inner) => {
+                let mut vars = self.free_vars_with_priorities(inner, prior_pools);
+                for (var, value) in bindings {
+                    let sort = self.sort_with_priorities(value, prior_pools);
+                    let term = self.add_with_priorities((var.clone(), sort).into(), prior_pools);
+                    vars.remove(&term);
+                }
+                vars
+            }
+            Term::Choice(bound_var, inner) => {
+                let mut vars = self.free_vars_with_priorities(inner, prior_pools);
+                let term = self.add_with_priorities(bound_var.clone().into(), prior_pools);
+                vars.remove(&term);
+                vars
+            }
+            Term::Var(..) => {
+                let mut set = IndexSet::with_capacity(1);
+                set.insert(term.clone());
+                set
+            }
+            Term::Const(_) | Term::Sort(_) => IndexSet::new(),
+        };
+        self.free_vars_cache.insert(term.clone(), set);
+        self.free_vars_cache.get(term).unwrap().clone()
+    }
+}
+
+impl TermPool for PrimitivePool {
+    fn bool_true(&self) -> Rc<Term> {
+        self.bool_true.clone()
+    }
+
+    fn bool_false(&self) -> Rc<Term> {
+        self.bool_false.clone()
+    }
+
+    fn add(&mut self, term: Term) -> Rc<Term> {
+        let term = self.storage.add(term);
+        self.compute_sort(&term);
+        term
+    }
+
+    fn sort(&self, term: &Rc<Term>) -> Rc<Term> {
+        self.sorts_cache[term].clone()
+    }
+
+    fn free_vars(&mut self, term: &Rc<Term>) -> IndexSet<Rc<Term>> {
+        self.free_vars_with_priorities(term, [])
+    }
+}
diff --git a/carcara/src/ast/pool/storage.rs b/carcara/src/ast/pool/storage.rs
new file mode 100644
index 00000000..1ebceaf0
--- /dev/null
+++ b/carcara/src/ast/pool/storage.rs
@@ -0,0 +1,66 @@
+//* The behaviour of the term pool could be modeled by a hash map from `Term` to `Rc<Term>`, but
+//* that would require allocating two copies of each term, one in the key of the hash map, and one
+//* inside the `Rc`. Instead, we store a hash set of `Rc<Term>`s, combining the key and the value
+//* into a single object. We access this hash set using a `&Term`, and if the entry is present, we
+//* clone it; otherwise, we allocate a new `Rc`.
+
+use crate::ast::*;
+use std::borrow::Borrow;
+
+/// Since `ast::Rc` intentionally implements hashing and equality by reference (instead of by
+/// value), we cannot safely implement `Borrow<Term>` for `Rc<Term>`, so we cannot access a
+/// `HashSet<Rc<Term>>` using a `&Term` as a key. To go around that, we use this struct that wraps
+/// an `Rc<Term>` and that re-implements hashing and equality by value, meaning we can implement
+/// `Borrow<Term>` for it, and use it as the contents of the hash set instead.
+#[derive(Debug, Clone, Eq)]
+struct ByValue(Rc<Term>);
+
+impl PartialEq for ByValue {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.as_ref() == other.0.as_ref()
+    }
+}
+
+impl Hash for ByValue {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.0.as_ref().hash(state);
+    }
+}
+
+impl Borrow<Term> for ByValue {
+    fn borrow(&self) -> &Term {
+        self.0.as_ref()
+    }
+}
+
+#[derive(Debug, Clone, Default)]
+pub struct Storage(IndexSet<ByValue>);
+
+impl Storage {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    pub fn add(&mut self, term: Term) -> Rc<Term> {
+        // If the `hash_set_entry` feature was stable, this would be much simpler to do using
+        // `get_or_insert_with` (and would avoid rehashing the term)
+        match self.0.get(&term) {
+            Some(t) => t.0.clone(),
+            None => {
+                let result = Rc::new(term);
+                self.0.insert(ByValue(result.clone()));
+                result
+            }
+        }
+    }
+
+    pub fn get(&self, term: &Term) -> Option<&Rc<Term>> {
+        self.0.get(term).map(|t| &t.0)
+    }
+
+    // This method is only necessary for the hash consing tests
+    #[cfg(test)]
+    pub fn into_vec(self) -> Vec<Rc<Term>> {
+        self.0.into_iter().map(|ByValue(t)| t).collect()
+    }
+}
diff --git a/carcara/src/ast/printer.rs b/carcara/src/ast/printer.rs
index 6d44f2f9..0d8c15d2 100644
--- a/carcara/src/ast/printer.rs
+++ b/carcara/src/ast/printer.rs
@@ -5,7 +5,7 @@ use crate::{
     parser::Token,
     utils::{is_symbol_character, DedupIterator},
 };
-use ahash::AHashMap;
+use indexmap::IndexMap;
 use std::{borrow::Cow, fmt, io};
 
 /// Prints a proof to the standard output.
@@ -17,7 +17,7 @@
     let mut stdout = io::stdout();
     let mut printer = AlethePrinter {
         inner: &mut stdout,
-        term_indices: use_sharing.then(AHashMap::new),
+        term_indices: use_sharing.then(IndexMap::new),
         term_sharing_variable_prefix: "@p_",
     };
     printer.write_proof(commands)
@@ -32,7 +32,7 @@ pub fn write_lia_smt_instance(
 ) -> io::Result<()> {
     let mut printer = AlethePrinter {
         inner: dest,
-        term_indices: use_sharing.then(AHashMap::new),
+        term_indices: use_sharing.then(IndexMap::new),
         term_sharing_variable_prefix: "p_",
     };
     printer.write_lia_smt_instance(clause)
@@ -57,7 +57,7 @@ impl PrintWithSharing for Rc<Term> {
         if let Some(indices) = &mut p.term_indices {
             // There are three cases where we don't use sharing when printing a term:
             //
-            // - Terminal terms (e.g., integers, reals, variables, etc.) could in theory be shared,
+            // - Terminal terms (i.e., constants or variables) could in theory be shared,
             // but, since they are very small, it's not worth it to give them a name.
// // - Sorts are represented as terms, but they are not actually terms in the grammar, so @@ -107,7 +107,7 @@ impl PrintWithSharing for Operator { struct AlethePrinter<'a> { inner: &'a mut dyn io::Write, - term_indices: Option, usize>>, + term_indices: Option, usize>>, term_sharing_variable_prefix: &'static str, } @@ -173,7 +173,8 @@ impl<'a> AlethePrinter<'a> { fn write_raw_term(&mut self, term: &Term) -> io::Result<()> { match term { - Term::Terminal(t) => write!(self.inner, "{}", t), + Term::Const(c) => write!(self.inner, "{}", c), + Term::Var(iden, _) => write!(self.inner, "{}", iden), Term::App(func, args) => self.write_s_expr(func, args), Term::Op(op, args) => self.write_s_expr(op, args), Term::Sort(sort) => write!(self.inner, "{}", sort), @@ -320,7 +321,7 @@ impl fmt::Display for Term { let mut buf = Vec::new(); let mut printer = AlethePrinter { inner: &mut buf, - term_indices: use_sharing.then(AHashMap::new), + term_indices: use_sharing.then(IndexMap::new), term_sharing_variable_prefix: "@p_", }; printer.write_raw_term(self).unwrap(); @@ -335,39 +336,38 @@ impl fmt::Debug for Term { } } -impl fmt::Display for Terminal { +impl fmt::Display for Constant { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Terminal::Integer(i) => write!(f, "{}", i), - Terminal::Real(r) => { + Constant::Integer(i) => write!(f, "{}", i), + Constant::Real(r) => { if r.is_integer() { write!(f, "{:?}.0", r.numer()) } else { write!(f, "{:?}", r.to_f64()) } } - Terminal::String(s) => write!(f, "\"{}\"", escape_string(s)), - Terminal::Var(iden, _) => write!(f, "{}", iden), + Constant::String(s) => write!(f, "\"{}\"", escape_string(s)), } } } -impl fmt::Display for Identifier { +impl fmt::Display for Ident { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Identifier::Simple(s) => write!(f, "{}", quote_symbol(s)), - Identifier::Indexed(s, indices) => { + Ident::Simple(s) => write!(f, "{}", quote_symbol(s)), + Ident::Indexed(s, indices) => { write_s_expr(f, format!("_ {}", quote_symbol(s)), indices) } } } } -impl fmt::Display for IdentifierIndex { +impl fmt::Display for IdentIndex { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - IdentifierIndex::Numeral(n) => write!(f, "{}", n), - IdentifierIndex::Symbol(s) => write!(f, "{}", quote_symbol(s)), + IdentIndex::Numeral(n) => write!(f, "{}", n), + IdentIndex::Symbol(s) => write!(f, "{}", quote_symbol(s)), } } } diff --git a/carcara/src/ast/rc.rs b/carcara/src/ast/rc.rs index 197e2564..27274e3d 100644 --- a/carcara/src/ast/rc.rs +++ b/carcara/src/ast/rc.rs @@ -1,13 +1,33 @@ //! This module implements a variant of `Rc` where equality and hashing are done by reference. -use std::{fmt, hash::Hash, ops::Deref, rc}; +use std::{fmt, hash::Hash, ops::Deref, sync}; -/// An `Rc` where equality and hashing are done by reference, instead of by value. +/// A wrapper for `std::rc::Rc` where equality and hashing are done by reference, instead of by +/// value. /// /// This means that two `Rc`s will not be considered equal and won't have the same hash value unless /// they point to the same allocation. This has the advantage that equality and hashing can be done /// in constant time, even for recursive structures. /// +/// The Carcara parser makes use of hash consing, meaning that each term is only allocated once, +/// even if it appears multiple times in the proof. 
This means that if we want to compare two terms
+/// for equality, we only need to compare them by reference, since if they are equal they will point
+/// to the same allocation. However, `std::rc::Rc` implements `PartialEq` by comparing the inner
+/// values for equality. If we simply used this implementation, each equality comparison would need
+/// to traverse the terms recursively, which would be prohibitively expensive. Instead, this wrapper
+/// overrides the `PartialEq` implementation to compare the pointers directly, allowing for constant
+/// time equality comparisons.
+///
+/// Similarly, when inserting terms in a hash map or set, we can also just hash the pointers
+/// instead of recursively hashing the inner value (as `std::rc::Rc`'s `Hash` implementation does).
+/// Therefore, this wrapper also overrides the implementation of the `Hash` trait.
+///
+/// Note: when using this struct, it's important to avoid constructing terms with `Rc::new` and
+/// instead prefer to construct them by adding them to a `TermPool`. This is because `Rc::new` will
+/// create a brand new allocation for that term, instead of reusing the existing allocation if that
+/// term was already added to the pool. Two identical terms created independently with `Rc::new`
+/// will not compare as equal.
+///
 /// # Examples
 ///
 /// ```
@@ -36,7 +56,7 @@ use std::{fmt, hash::Hash, ops::Deref, rc};
 /// assert!(set.contains(&c));
 /// ```
 #[derive(Eq)]
-pub struct Rc<T>(rc::Rc<T>);
+pub struct Rc<T>(sync::Arc<T>);
 
 // If we simply `#[derive(Clone)]`, it would require that the type parameter `T` also implements
 // `Clone`, even though it is of course not needed. For more info, see:
@@ -49,13 +69,13 @@ impl<T> Clone for Rc<T> {
 
 impl<T> PartialEq for Rc<T> {
     fn eq(&self, other: &Self) -> bool {
-        rc::Rc::ptr_eq(&self.0, &other.0)
+        sync::Arc::ptr_eq(&self.0, &other.0)
     }
 }
 
 impl<T> Hash for Rc<T> {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        rc::Rc::as_ptr(&self.0).hash(state);
+        sync::Arc::as_ptr(&self.0).hash(state);
     }
 }
@@ -78,7 +98,7 @@ impl<T> AsRef<T> for Rc<T> {
 // Implements `From<U>` for every `U` that can be converted into an `rc::Rc<T>`
 impl<T, U> From<U> for Rc<T>
 where
-    rc::Rc<T>: From<U>,
+    sync::Arc<T>: From<U>,
 {
     fn from(inner: U) -> Self {
         Self(inner.into())
@@ -108,11 +128,11 @@ impl<T> Rc<T> {
     /// Constructs a new `Rc`.
     pub fn new(value: T) -> Self {
         #[allow(clippy::disallowed_methods)]
-        Self(rc::Rc::new(value))
+        Self(sync::Arc::new(value))
     }
 
     /// Similar to [`std::rc::Rc::strong_count`].
     pub fn strong_count(this: &Self) -> usize {
-        rc::Rc::strong_count(&this.0)
+        sync::Arc::strong_count(&this.0)
     }
 }
diff --git a/carcara/src/ast/substitution.rs b/carcara/src/ast/substitution.rs
index 532f2d8b..1d7fdb04 100644
--- a/carcara/src/ast/substitution.rs
+++ b/carcara/src/ast/substitution.rs
@@ -1,7 +1,7 @@
 //! Algorithms for creating and applying capture-avoiding substitutions over terms.
 
 use super::{BindingList, Rc, SortedVar, Term, TermPool};
-use ahash::{AHashMap, AHashSet};
+use indexmap::{IndexMap, IndexSet};
 use thiserror::Error;
 
 /// The error type for errors when constructing or applying substitutions.
@@ -36,27 +36,27 @@ type SubstitutionResult<T> = Result<T, SubstitutionError>;
 /// actually be `(forall ((y' Int)) (= y y'))`.
 pub struct Substitution {
     /// The substitution's mappings.
-    pub(crate) map: AHashMap<Rc<Term>, Rc<Term>>,
+    pub(crate) map: IndexMap<Rc<Term>, Rc<Term>>,
 
     /// The variables that should be renamed to preserve capture-avoidance, if they are bound by a
     /// binder term.
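+    /// (Computed on demand by `compute_should_be_renamed`, which collects the free variables of
+    /// every term in the mapping's image; this field stays `None` until that method runs.)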
- should_be_renamed: Option>>, - cache: AHashMap, Rc>, + should_be_renamed: Option>>, + cache: IndexMap, Rc>, } impl Substitution { /// Constructs an empty substitution. pub fn empty() -> Self { Self { - map: AHashMap::new(), + map: IndexMap::new(), should_be_renamed: None, - cache: AHashMap::new(), + cache: IndexMap::new(), } } /// Constructs a singleton substitution mapping `x` to `t`. This returns an error if the sorts /// of the given terms are not the same, or if `x` is not a variable term. - pub fn single(pool: &mut TermPool, x: Rc, t: Rc) -> SubstitutionResult { + pub fn single(pool: &mut dyn TermPool, x: Rc, t: Rc) -> SubstitutionResult { let mut this = Self::empty(); this.insert(pool, x, t)?; Ok(this) @@ -65,8 +65,11 @@ impl Substitution { /// Constructs a new substitution from an arbitrary mapping of terms to other terms. This /// returns an error if any term in the left-hand side is not a variable, or if any term is /// mapped to a term of a different sort. - pub fn new(pool: &mut TermPool, map: AHashMap, Rc>) -> SubstitutionResult { - for (k, v) in map.iter() { + pub fn new( + pool: &mut dyn TermPool, + map: IndexMap, Rc>, + ) -> SubstitutionResult { + for (k, v) in &map { if !k.is_var() { return Err(SubstitutionError::NotAVariable(k.clone())); } @@ -78,7 +81,7 @@ impl Substitution { Ok(Self { map, should_be_renamed: None, - cache: AHashMap::new(), + cache: IndexMap::new(), }) } @@ -91,7 +94,7 @@ impl Substitution { /// the sorts of the given terms are not the same, or if `x` is not a variable term. pub(crate) fn insert( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, x: Rc, t: Rc, ) -> SubstitutionResult<()> { @@ -109,7 +112,7 @@ impl Substitution { if let Some(should_be_renamed) = &mut self.should_be_renamed { if x != t { - should_be_renamed.extend(pool.free_vars(&t).iter().cloned()); + should_be_renamed.extend(pool.free_vars(&t)); if x.is_var() { should_be_renamed.insert(x.clone()); } @@ -122,7 +125,7 @@ impl Substitution { /// Computes which binder variables will need to be renamed, and stores the result in /// `self.should_be_renamed`. - fn compute_should_be_renamed(&mut self, pool: &mut TermPool) { + fn compute_should_be_renamed(&mut self, pool: &mut dyn TermPool) { if self.should_be_renamed.is_some() { return; } @@ -143,12 +146,12 @@ impl Substitution { // // See https://en.wikipedia.org/wiki/Lambda_calculus#Capture-avoiding_substitutions for // more details. - let mut should_be_renamed = AHashSet::new(); - for (x, t) in self.map.iter() { + let mut should_be_renamed = IndexSet::new(); + for (x, t) in &self.map { if x == t { continue; // We ignore reflexive substitutions } - should_be_renamed.extend(pool.free_vars(t).iter().cloned()); + should_be_renamed.extend(pool.free_vars(t).into_iter()); if x.is_var() { should_be_renamed.insert(x.clone()); } @@ -157,7 +160,7 @@ impl Substitution { } /// Applies the substitution to `term`, and returns the result as a new term. - pub fn apply(&mut self, pool: &mut TermPool, term: &Rc) -> Rc { + pub fn apply(&mut self, pool: &mut dyn TermPool, term: &Rc) -> Rc { macro_rules! apply_to_sequence { ($sequence:expr) => { $sequence @@ -201,7 +204,7 @@ impl Substitution { Term::Lambda(b, t) => { self.apply_to_binder(pool, term, b.as_ref(), t, true, Term::Lambda) } - Term::Terminal(_) | Term::Sort(_) => term.clone(), + Term::Const(_) | Term::Var(..) 
| Term::Sort(_) => term.clone(), }; // Since frequently a term will have more than one identical subterms, we insert the @@ -214,7 +217,7 @@ impl Substitution { fn can_skip_instead_of_renaming( &self, - pool: &mut TermPool, + pool: &mut dyn TermPool, binding_list: &[SortedVar], ) -> bool { // Note: this method assumes that `binding_list` is a "sort" binding list. "Value" lists add @@ -245,7 +248,7 @@ impl Substitution { /// binder is a `let` or `lambda` term, `is_value_list` should be true. fn apply_to_binder) -> Term>( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, original_term: &Rc, binding_list: &[SortedVar], inner: &Rc, @@ -289,19 +292,19 @@ impl Substitution { /// a `let` or `lambda` term, `is_value_list` should be true. fn rename_binding_list( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, binding_list: &[SortedVar], is_value_list: bool, ) -> (BindingList, Self) { let mut new_substitution = Self::empty(); - let mut new_vars = AHashSet::new(); + let mut new_vars = IndexSet::new(); let new_binding_list = binding_list .iter() .map(|(var, value)| { // If the binding list is a "sort" binding list, then `value` will be the variable's // sort. Otherwise, we need to get the sort of `value` let sort = if is_value_list { - pool.add(Term::Sort(pool.sort(value).clone())) + pool.sort(value) } else { value.clone() }; @@ -350,12 +353,11 @@ impl Substitution { #[cfg(test)] mod tests { use super::*; - use crate::parser::*; + use crate::{ast::PrimitivePool, parser::*}; fn run_test(definitions: &str, original: &str, x: &str, t: &str, result: &str) { - let mut pool = TermPool::new(); - let mut parser = - Parser::new(&mut pool, definitions.as_bytes(), true, false, false).unwrap(); + let mut pool = PrimitivePool::new(); + let mut parser = Parser::new(&mut pool, Config::new(), definitions.as_bytes()).unwrap(); parser.parse_problem().unwrap(); let [original, x, t, result] = [original, x, t, result].map(|s| { @@ -363,7 +365,7 @@ mod tests { parser.parse_term().unwrap() }); - let mut map = AHashMap::new(); + let mut map = IndexMap::new(); map.insert(x, t); let got = Substitution::new(&mut pool, map) diff --git a/carcara/src/ast/tests.rs b/carcara/src/ast/tests.rs index 5343c524..4b034429 100644 --- a/carcara/src/ast/tests.rs +++ b/carcara/src/ast/tests.rs @@ -1,18 +1,18 @@ -use crate::{ast::TermPool, parser::tests::parse_terms}; -use ahash::AHashSet; +use crate::{ + ast::{pool::PrimitivePool, TermPool}, + parser::tests::parse_terms, +}; +use indexmap::IndexSet; #[test] fn test_free_vars() { fn run_tests(definitions: &str, cases: &[(&str, &[&str])]) { for &(term, expected) in cases { - let mut pool = TermPool::new(); + let mut pool = PrimitivePool::new(); let [root] = parse_terms(&mut pool, definitions, [term]); - let expected: AHashSet<_> = expected.iter().copied().collect(); - let got: AHashSet<_> = pool - .free_vars(&root) - .iter() - .map(|t| t.as_var().unwrap()) - .collect(); + let expected: IndexSet<_> = expected.iter().copied().collect(); + let set = pool.free_vars(&root); + let got: IndexSet<_> = set.iter().map(|t| t.as_var().unwrap()).collect(); assert_eq!(expected, got); } @@ -37,23 +37,23 @@ fn test_free_vars() { } #[test] -fn test_deep_eq() { +fn test_polyeq() { enum TestType { - ModReordering, + Polyeq, AlphaEquiv, } fn run_tests(definitions: &str, cases: &[(&str, &str)], test_type: TestType) { - let mut pool = TermPool::new(); + let mut pool = PrimitivePool::new(); for (a, b) in cases { let [a, b] = parse_terms(&mut pool, definitions, [a, b]); let mut time = 
std::time::Duration::ZERO; match test_type { - TestType::ModReordering => { - assert!(super::deep_eq::deep_eq(&a, &b, &mut time)); + TestType::Polyeq => { + assert!(super::polyeq::polyeq(&a, &b, &mut time)); } TestType::AlphaEquiv => { - assert!(super::deep_eq::are_alpha_equivalent(&a, &b, &mut time)); + assert!(super::polyeq::alpha_equiv(&a, &b, &mut time)); } } } @@ -77,7 +77,7 @@ fn test_deep_eq() { "(ite (= b a) (= (+ x y) x) (and p (not (= y x))))", ), ], - TestType::ModReordering, + TestType::Polyeq, ); run_tests( definitions, diff --git a/carcara/src/benchmarking/metrics.rs b/carcara/src/benchmarking/metrics.rs index 0afe3a61..aa5be487 100644 --- a/carcara/src/benchmarking/metrics.rs +++ b/carcara/src/benchmarking/metrics.rs @@ -142,7 +142,7 @@ where } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct OnlineMetrics { total: T, count: usize, diff --git a/carcara/src/benchmarking/mod.rs b/carcara/src/benchmarking/mod.rs index fc36fb35..dbcd88eb 100644 --- a/carcara/src/benchmarking/mod.rs +++ b/carcara/src/benchmarking/mod.rs @@ -4,15 +4,15 @@ mod tests; pub use metrics::*; -use ahash::AHashMap; -use std::{fmt, io, time::Duration}; +use indexmap::{map::Entry, IndexMap, IndexSet}; +use std::{fmt, hash::Hash, io, sync::Arc, time::Duration}; -fn combine_map(mut a: AHashMap, b: AHashMap) -> AHashMap +fn combine_map(mut a: IndexMap, b: IndexMap) -> IndexMap where + S: Eq + Hash, V: MetricsUnit, M: Metrics + Default, { - use std::collections::hash_map::Entry; for (k, v) in b { match a.entry(k) { Entry::Occupied(mut e) => { @@ -49,32 +49,32 @@ pub struct RunMeasurement { pub parsing: Duration, pub checking: Duration, pub elaboration: Duration, + pub scheduling: Duration, pub total: Duration, - pub deep_eq: Duration, + pub polyeq: Duration, pub assume: Duration, pub assume_core: Duration, } -// Higher kinded types would be very useful here. Ideally, I would like `BenchmarkResults` to be -// generic on any kind that implements `Metrics`, like `OnlineMetrics` or `OfflineMetrics`. 
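+// In this version, the results type is monomorphized to `OnlineMetrics` directly; the CSV
+// output path is covered separately by `CsvBenchmarkResults` (with `OfflineMetrics`) below.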
-#[derive(Debug, Default)]
-pub struct BenchmarkResults<ByRun, ByStep, ByRunF64, ByDeepEq> {
-    pub parsing: ByRun,
-    pub checking: ByRun,
-    pub elaborating: ByRun,
-    pub total_accounted_for: ByRun,
-    pub total: ByRun,
-    pub step_time: ByStep,
-    pub step_time_by_file: AHashMap<String, ByStep>,
-    pub step_time_by_rule: AHashMap<String, ByStep>,
-
-    pub deep_eq_time: ByRun,
-    pub deep_eq_time_ratio: ByRunF64,
-    pub assume_time: ByRun,
-    pub assume_time_ratio: ByRunF64,
-    pub assume_core_time: ByRun,
-
-    pub deep_eq_depths: ByDeepEq,
+#[derive(Debug, Default, Clone)]
+pub struct OnlineBenchmarkResults {
+    pub parsing: OnlineMetrics<RunId>,
+    pub checking: OnlineMetrics<RunId>,
+    pub elaborating: OnlineMetrics<RunId>,
+    pub scheduling: OnlineMetrics<RunId>,
+    pub total_accounted_for: OnlineMetrics<RunId>,
+    pub total: OnlineMetrics<RunId>,
+    pub step_time: OnlineMetrics<StepId>,
+    pub step_time_by_file: IndexMap<String, OnlineMetrics<StepId>>,
+    pub step_time_by_rule: IndexMap<String, OnlineMetrics<StepId>>,
+
+    pub polyeq_time: OnlineMetrics<RunId>,
+    pub polyeq_time_ratio: OnlineMetrics<RunId, f64>,
+    pub assume_time: OnlineMetrics<RunId>,
+    pub assume_time_ratio: OnlineMetrics<RunId, f64>,
+    pub assume_core_time: OnlineMetrics<RunId>,
+
+    pub polyeq_depths: OnlineMetrics<(), usize>,
 
     pub num_assumes: usize,
     pub num_easy_assumes: usize,
@@ -82,27 +82,7 @@
     pub had_error: bool,
 }
 
-pub type OnlineBenchmarkResults = BenchmarkResults<
-    OnlineMetrics<RunId>,
-    OnlineMetrics<StepId>,
-    OnlineMetrics<RunId, f64>,
-    OnlineMetrics<(), usize>,
->;
-
-pub type OfflineBenchmarkResults = BenchmarkResults<
-    OfflineMetrics<RunId>,
-    OfflineMetrics<StepId>,
-    OfflineMetrics<RunId, f64>,
-    OfflineMetrics<(), usize>,
->;
-
-impl<ByRun, ByStep, ByRunF64, ByDeepEq> BenchmarkResults<ByRun, ByStep, ByRunF64, ByDeepEq>
-where
-    ByRun: Metrics<RunId> + Default,
-    ByStep: Metrics<StepId> + Default,
-    ByRunF64: Metrics<RunId, f64> + Default,
-    ByDeepEq: Metrics<(), usize> + Default,
-{
+impl OnlineBenchmarkResults {
     pub fn new() -> Self {
         Default::default()
     }
@@ -113,50 +93,195 @@
     }
 
     /// The time per run to completely parse the proof.
-    pub fn parsing(&self) -> &ByRun {
+    pub fn parsing(&self) -> &OnlineMetrics<RunId> {
         &self.parsing
     }
 
     /// The time per run to check all the steps in the proof.
-    pub fn checking(&self) -> &ByRun {
+    pub fn checking(&self) -> &OnlineMetrics<RunId> {
         &self.checking
     }
 
     /// The time per run to elaborate the proof.
-    pub fn elaborating(&self) -> &ByRun {
+    pub fn elaborating(&self) -> &OnlineMetrics<RunId> {
         &self.elaborating
     }
 
+    /// The time per run to schedule the threads' tasks.
+    pub fn scheduling(&self) -> &OnlineMetrics<RunId> {
+        &self.scheduling
+    }
+
     /// The combined time per run to parse, check, and elaborate all the steps in the proof.
-    pub fn total_accounted_for(&self) -> &ByRun {
+    pub fn total_accounted_for(&self) -> &OnlineMetrics<RunId> {
         &self.total_accounted_for
     }
 
     /// The total time spent per run. Should be pretty similar to `total_accounted_for`.
-    pub fn total(&self) -> &ByRun {
+    pub fn total(&self) -> &OnlineMetrics<RunId> {
         &self.total
     }
 
     /// The time spent checking each step.
-    pub fn step_time(&self) -> &ByStep {
+    pub fn step_time(&self) -> &OnlineMetrics<StepId> {
         &self.step_time
     }
 
     /// For each file, the time spent checking each step in the file.
-    pub fn step_time_by_file(&self) -> &AHashMap<String, ByStep> {
+    pub fn step_time_by_file(&self) -> &IndexMap<String, OnlineMetrics<StepId>> {
         &self.step_time_by_file
     }
 
     /// For each rule, the time spent checking each step that uses that rule.
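+    /// (This map is what [`OnlineBenchmarkResults::print`] uses for the by-rule breakdown,
+    /// sorted by total or mean time.)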
- pub fn step_time_by_rule(&self) -> &AHashMap { + pub fn step_time_by_rule(&self) -> &IndexMap> { &self.step_time_by_rule } + + /// Prints the benchmark results + pub fn print(&self, sort_by_total: bool) { + let [parsing, checking, elaborating, scheduling, accounted_for, total, assume_time, assume_core_time, polyeq_time] = + [ + self.parsing(), + self.checking(), + self.elaborating(), + self.scheduling(), + self.total_accounted_for(), + self.total(), + &self.assume_time, + &self.assume_core_time, + &self.polyeq_time, + ] + .map(|m| { + if sort_by_total { + format!("{:#}", m) + } else { + format!("{}", m) + } + }); + + println!("parsing: {}", parsing); + println!("checking: {}", checking); + if !elaborating.is_empty() { + println!("elaborating: {}", elaborating); + } + println!("scheduling: {}", scheduling); + + println!( + "on assume: {} ({:.02}% of checking time)", + assume_time, + 100.0 * self.assume_time.mean().as_secs_f64() / self.checking().mean().as_secs_f64(), + ); + println!("on assume (core): {}", assume_core_time); + println!("assume ratio: {}", self.assume_time_ratio); + println!( + "on polyeq: {} ({:.02}% of checking time)", + polyeq_time, + 100.0 * self.polyeq_time.mean().as_secs_f64() / self.checking().mean().as_secs_f64(), + ); + println!("polyeq ratio: {}", self.polyeq_time_ratio); + + println!("total accounted for: {}", accounted_for); + println!("total: {}", total); + + let data_by_rule = self.step_time_by_rule(); + let mut data_by_rule: Vec<_> = data_by_rule.iter().collect(); + data_by_rule.sort_by_key(|(_, m)| if sort_by_total { m.total() } else { m.mean() }); + + println!("by rule:"); + for (rule, data) in data_by_rule { + print!(" {: <18}", rule); + if sort_by_total { + println!("{:#}", data); + } else { + println!("{}", data); + } + } + + println!("worst cases:"); + if !self.step_time().is_empty() { + let worst_step = self.step_time().max(); + println!(" step: {} ({:?})", worst_step.0, worst_step.1); + } + + let worst_file_parsing = self.parsing().max(); + println!( + " file (parsing): {} ({:?})", + worst_file_parsing.0 .0, worst_file_parsing.1 + ); + + let worst_file_checking = self.checking().max(); + println!( + " file (checking): {} ({:?})", + worst_file_checking.0 .0, worst_file_checking.1 + ); + + let worst_file_assume = self.assume_time_ratio.max(); + println!( + " file (assume): {} ({:.04}%)", + worst_file_assume.0 .0, + worst_file_assume.1 * 100.0 + ); + + let worst_file_polyeq = self.polyeq_time_ratio.max(); + println!( + " file (polyeq): {} ({:.04}%)", + worst_file_polyeq.0 .0, + worst_file_polyeq.1 * 100.0 + ); + + let worst_file_total = self.total().max(); + println!( + " file overall: {} ({:?})", + worst_file_total.0 .0, worst_file_total.1 + ); + + let num_hard_assumes = self.num_assumes - self.num_easy_assumes; + let percent_easy = (self.num_easy_assumes as f64) * 100.0 / (self.num_assumes as f64); + let percent_hard = (num_hard_assumes as f64) * 100.0 / (self.num_assumes as f64); + println!(" number of assumes: {}", self.num_assumes); + println!( + " (easy): {} ({:.02}%)", + self.num_easy_assumes, percent_easy + ); + println!( + " (hard): {} ({:.02}%)", + num_hard_assumes, percent_hard + ); + + let depths = &self.polyeq_depths; + if !depths.is_empty() { + println!(" max polyeq depth: {}", depths.max().1); + println!(" total polyeq depth: {}", depths.total()); + println!(" number of polyeq checks: {}", depths.count()); + println!(" mean depth: {:.4}", depths.mean()); + println!( + "standard deviation of depth: {:.4}", + depths.standard_deviation() + 
); + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct InternedStepId { + pub(crate) file: Arc, + pub(crate) step_id: Arc, + pub(crate) rule: Arc, } +impl fmt::Display for InternedStepId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}:{} ({})", self.file, self.step_id, self.rule) + } +} + +type InternedRunId = (Arc, usize); + #[derive(Default)] pub struct CsvBenchmarkResults { - runs: AHashMap, - step_time_by_rule: AHashMap>, + strings: IndexSet>, + runs: IndexMap, + step_time_by_rule: IndexMap, OfflineMetrics>, is_holey: bool, num_errors: usize, } @@ -174,6 +299,17 @@ impl CsvBenchmarkResults { self.num_errors } + fn intern(&mut self, s: &str) -> Arc { + match self.strings.get(s) { + Some(interned) => interned.clone(), + None => { + let result: Arc = Arc::from(s); + self.strings.insert(result.clone()); + result + } + } + } + pub fn write_csv( self, runs_dest: &mut dyn io::Write, @@ -184,18 +320,18 @@ impl CsvBenchmarkResults { } fn write_runs_csv( - data: AHashMap, + data: IndexMap, dest: &mut dyn io::Write, ) -> io::Result<()> { writeln!( dest, "proof_file,run_id,parsing,checking,elaboration,total_accounted_for,\ - total,deep_eq,deep_eq_ratio,assume,assume_ratio" + total,polyeq,polyeq_ratio,assume,assume_ratio" )?; for (id, m) in data { let total_accounted_for = m.parsing + m.checking; - let deep_eq_ratio = m.deep_eq.as_secs_f64() / m.checking.as_secs_f64(); + let polyeq_ratio = m.polyeq.as_secs_f64() / m.checking.as_secs_f64(); let assume_ratio = m.assume.as_secs_f64() / m.checking.as_secs_f64(); writeln!( dest, @@ -207,8 +343,8 @@ impl CsvBenchmarkResults { m.elaboration.as_nanos(), total_accounted_for.as_nanos(), m.total.as_nanos(), - m.deep_eq.as_nanos(), - deep_eq_ratio, + m.polyeq.as_nanos(), + polyeq_ratio, m.assume.as_nanos(), assume_ratio, )?; @@ -218,7 +354,7 @@ impl CsvBenchmarkResults { } fn write_by_rule_csv( - data: AHashMap>, + data: IndexMap, OfflineMetrics>, dest: &mut dyn io::Write, ) -> io::Result<()> { let mut data: Vec<_> = data.into_iter().collect(); @@ -252,7 +388,7 @@ impl CsvBenchmarkResults { pub trait CollectResults { fn add_step_measurement(&mut self, file: &str, step_id: &str, rule: &str, time: Duration); fn add_assume_measurement(&mut self, file: &str, id: &str, is_easy: bool, time: Duration); - fn add_deep_eq_depth(&mut self, depth: usize); + fn add_polyeq_depth(&mut self, depth: usize); fn add_run_measurement(&mut self, id: &RunId, measurement: RunMeasurement); fn register_holey(&mut self); fn register_error(&mut self, error: &crate::Error); @@ -262,14 +398,7 @@ pub trait CollectResults { Self: Sized; } -impl CollectResults - for BenchmarkResults -where - ByRun: Metrics + Default, - ByStep: Metrics + Default, - ByRunF64: Metrics + Default, - ByDeepEq: Metrics<(), usize> + Default, -{ +impl CollectResults for OnlineBenchmarkResults { fn add_step_measurement(&mut self, file: &str, step_id: &str, rule: &str, time: Duration) { let file = file.to_owned(); let rule = rule.to_owned(); @@ -295,8 +424,8 @@ where self.add_step_measurement(file, id, "assume", time); } - fn add_deep_eq_depth(&mut self, depth: usize) { - self.deep_eq_depths.add_sample(&(), depth); + fn add_polyeq_depth(&mut self, depth: usize) { + self.polyeq_depths.add_sample(&(), depth); } fn add_run_measurement(&mut self, id: &RunId, measurement: RunMeasurement) { @@ -304,8 +433,9 @@ where parsing, checking, elaboration, + scheduling, total, - deep_eq, + polyeq, assume, assume_core, } = measurement; @@ -313,16 +443,17 @@ where self.parsing.add_sample(id, 
parsing); self.checking.add_sample(id, checking); self.elaborating.add_sample(id, elaboration); + self.scheduling.add_sample(id, scheduling); self.total_accounted_for.add_sample(id, parsing + checking); self.total.add_sample(id, total); - self.deep_eq_time.add_sample(id, deep_eq); + self.polyeq_time.add_sample(id, polyeq); self.assume_time.add_sample(id, assume); self.assume_core_time.add_sample(id, assume_core); - let deep_eq_ratio = deep_eq.as_secs_f64() / checking.as_secs_f64(); + let polyeq_ratio = polyeq.as_secs_f64() / checking.as_secs_f64(); let assume_ratio = assume.as_secs_f64() / checking.as_secs_f64(); - self.deep_eq_time_ratio.add_sample(id, deep_eq_ratio); + self.polyeq_time_ratio.add_sample(id, polyeq_ratio); self.assume_time_ratio.add_sample(id, assume_ratio); } @@ -331,19 +462,20 @@ where parsing: a.parsing.combine(b.parsing), checking: a.checking.combine(b.checking), elaborating: a.elaborating.combine(b.elaborating), + scheduling: a.scheduling.combine(b.scheduling), total_accounted_for: a.total_accounted_for.combine(b.total_accounted_for), total: a.total.combine(b.total), step_time: a.step_time.combine(b.step_time), step_time_by_file: combine_map(a.step_time_by_file, b.step_time_by_file), step_time_by_rule: combine_map(a.step_time_by_rule, b.step_time_by_rule), - deep_eq_time: a.deep_eq_time.combine(b.deep_eq_time), - deep_eq_time_ratio: a.deep_eq_time_ratio.combine(b.deep_eq_time_ratio), + polyeq_time: a.polyeq_time.combine(b.polyeq_time), + polyeq_time_ratio: a.polyeq_time_ratio.combine(b.polyeq_time_ratio), assume_time: a.assume_time.combine(b.assume_time), assume_time_ratio: a.assume_time_ratio.combine(b.assume_time_ratio), assume_core_time: a.assume_core_time.combine(b.assume_core_time), - deep_eq_depths: a.deep_eq_depths.combine(b.deep_eq_depths), + polyeq_depths: a.polyeq_depths.combine(b.polyeq_depths), num_assumes: a.num_assumes + b.num_assumes, num_easy_assumes: a.num_easy_assumes + b.num_easy_assumes, is_holey: a.is_holey || b.is_holey, @@ -362,13 +494,14 @@ where impl CollectResults for CsvBenchmarkResults { fn add_step_measurement(&mut self, file: &str, step_id: &str, rule: &str, time: Duration) { - let id = StepId { - file: file.into(), - step_id: step_id.into(), - rule: rule.into(), + let rule = self.intern(rule); + let id = InternedStepId { + file: self.intern(file), + step_id: self.intern(step_id), + rule: rule.clone(), }; self.step_time_by_rule - .entry(rule.to_owned()) + .entry(rule) .or_default() .add_sample(&id, time); } @@ -377,10 +510,11 @@ impl CollectResults for CsvBenchmarkResults { self.add_step_measurement(file, id, "assume", time); } - fn add_deep_eq_depth(&mut self, _: usize) {} + fn add_polyeq_depth(&mut self, _: usize) {} - fn add_run_measurement(&mut self, id: &RunId, measurement: RunMeasurement) { - self.runs.insert(id.clone(), measurement); + fn add_run_measurement(&mut self, (file, i): &RunId, measurement: RunMeasurement) { + let id = (self.intern(file), *i); + self.runs.insert(id, measurement); } fn register_holey(&mut self) { diff --git a/carcara/src/checker/context.rs b/carcara/src/checker/context.rs deleted file mode 100644 index 65fed185..00000000 --- a/carcara/src/checker/context.rs +++ /dev/null @@ -1,157 +0,0 @@ -use crate::ast::*; -use ahash::AHashSet; - -pub struct Context { - pub mappings: Vec<(Rc, Rc)>, - pub bindings: AHashSet, - pub cumulative_substitution: Option, -} - -#[derive(Default)] -pub struct ContextStack { - stack: Vec, - num_cumulative_calculated: usize, -} - -impl ContextStack { - pub fn new() -> Self { - 
-        Default::default()
-    }
-
-    pub fn len(&self) -> usize {
-        self.stack.len()
-    }
-
-    pub fn is_empty(&self) -> bool {
-        self.len() == 0
-    }
-
-    pub fn last(&self) -> Option<&Context> {
-        self.stack.last()
-    }
-
-    pub fn last_mut(&mut self) -> Option<&mut Context> {
-        self.stack.last_mut()
-    }
-
-    pub fn push(
-        &mut self,
-        pool: &mut TermPool,
-        assignment_args: &[(String, Rc<Term>)],
-        variable_args: &[SortedVar],
-    ) -> Result<(), SubstitutionError> {
-        // Since some rules (like `refl`) need to apply substitutions until a fixed point, we
-        // precompute these substitutions into a separate hash map. This assumes that the assignment
-        // arguments are in the correct order.
-        let mut substitution = Substitution::empty();
-        let mut substitution_until_fixed_point = Substitution::empty();
-
-        // We build the `substitution_until_fixed_point` hash map from the bottom up, by using the
-        // substitutions already introduced to transform the result of a new substitution before
-        // inserting it into the hash map. So for instance, if the substitutions are `(:= y z)` and
-        // `(:= x (f y))`, we insert the first substitution, and then, when introducing the second,
-        // we use the current state of the hash map to transform `(f y)` into `(f z)`. The
-        // resulting hash map will then contain `(:= y z)` and `(:= x (f z))`
-        for (var, value) in assignment_args.iter() {
-            let sort = Term::Sort(pool.sort(value).clone());
-            let var_term = Term::var(var, pool.add(sort));
-            let var_term = pool.add(var_term);
-            substitution.insert(pool, var_term.clone(), value.clone())?;
-            let new_value = substitution_until_fixed_point.apply(pool, value);
-            substitution_until_fixed_point.insert(pool, var_term, new_value)?;
-        }
-
-        let mappings = assignment_args
-            .iter()
-            .map(|(var, value)| {
-                let sort = Term::Sort(pool.sort(value).clone());
-                let var_term = (var.clone(), pool.add(sort)).into();
-                (pool.add(var_term), value.clone())
-            })
-            .collect();
-        let bindings = variable_args.iter().cloned().collect();
-        self.stack.push(Context {
-            mappings,
-            bindings,
-            cumulative_substitution: None,
-        });
-        Ok(())
-    }
-
-    pub fn pop(&mut self) {
-        self.stack.pop();
-        self.num_cumulative_calculated =
-            std::cmp::min(self.num_cumulative_calculated, self.stack.len());
-    }
-
-    fn catch_up_cumulative(&mut self, pool: &mut TermPool, up_to: usize) {
-        for i in self.num_cumulative_calculated..std::cmp::max(up_to + 1, self.len()) {
-            let simultaneous = build_simultaneous_substitution(pool, &self.stack[i].mappings).map;
-            let mut cumulative_substitution = simultaneous.clone();
-
-            if i > 0 {
-                if let Some(previous_context) = self.stack.get(i - 1) {
-                    let previous_substitution =
-                        previous_context.cumulative_substitution.as_ref().unwrap();
-
-                    for (k, v) in previous_substitution.map.iter() {
-                        let value = match simultaneous.get(v) {
-                            Some(new_value) => new_value,
-                            None => v,
-                        };
-                        cumulative_substitution.insert(k.clone(), value.clone());
-                    }
-                }
-            }
-            self.stack[i].cumulative_substitution =
-                Some(Substitution::new(pool, cumulative_substitution).unwrap());
-            self.num_cumulative_calculated = i + 1;
-        }
-    }
-
-    fn get_substitution(&mut self, pool: &mut TermPool, index: usize) -> &mut Substitution {
-        assert!(index < self.len());
-        self.catch_up_cumulative(pool, index);
-        self.stack[index].cumulative_substitution.as_mut().unwrap()
-    }
-
-    pub fn apply_previous(&mut self, pool: &mut TermPool, term: &Rc<Term>) -> Rc<Term> {
-        if self.len() < 2 {
-            term.clone()
-        } else {
-            self.get_substitution(pool, self.len() - 2)
-                .apply(pool, term)
-        }
-    }
-
-    pub fn apply(&mut self, pool: &mut TermPool, term: &Rc<Term>) -> Rc<Term> {
-        if self.is_empty() {
-            term.clone()
-        } else {
-            self.get_substitution(pool, self.len() - 1)
-                .apply(pool, term)
-        }
-    }
-}
-
-fn build_simultaneous_substitution(
-    pool: &mut TermPool,
-    mappings: &[(Rc<Term>, Rc<Term>)],
-) -> Substitution {
-    let mut result = Substitution::empty();
-
-    // We build the simultaneous substitution from the bottom up, by using the mappings already
-    // introduced to transform the result of a new mapping before inserting it into the
-    // substitution. So for instance, if the mappings are `y -> z` and `x -> (f y)`, we insert the
-    // first mapping, and then, when introducing the second, we use the current state of the
-    // substitution to transform `(f y)` into `(f z)`. The result will then contain `y -> z` and
-    // `x -> (f z)`.
-    for (var, value) in mappings {
-        let new_value = result.apply(pool, value);
-
-        // We can unwrap here safely because, by construction, the sort of `var` is the
-        // same as the sort of `new_value`
-        result.insert(pool, var.clone(), new_value).unwrap();
-    }
-    result
-}
diff --git a/carcara/src/checker/error.rs b/carcara/src/checker/error.rs
index 2565196e..3213b792 100644
--- a/carcara/src/checker/error.rs
+++ b/carcara/src/checker/error.rs
@@ -248,29 +248,29 @@ pub enum LinearArithmeticError {
 
 #[derive(Debug, Error)]
 pub enum LiaGenericError {
-    #[error("failed to spawn cvc5 process")]
-    FailedSpawnCvc5(io::Error),
+    #[error("failed to spawn solver process")]
+    FailedSpawnSolver(io::Error),
 
-    #[error("failed to write to cvc5 stdin")]
-    FailedWriteToCvc5Stdin(io::Error),
+    #[error("failed to write to solver stdin")]
+    FailedWriteToSolverStdin(io::Error),
 
-    #[error("error while waiting for cvc5 to exit")]
-    FailedWaitForCvc5(io::Error),
+    #[error("error while waiting for solver to exit")]
+    FailedWaitForSolver(io::Error),
 
-    #[error("cvc5 gave invalid output")]
-    Cvc5GaveInvalidOutput,
+    #[error("solver gave invalid output")]
+    SolverGaveInvalidOutput,
 
-    #[error("cvc5 output not unsat")]
-    Cvc5OutputNotUnsat,
+    #[error("solver output not unsat")]
+    OutputNotUnsat,
 
-    #[error("cvc5 timed out when solving problem")]
-    Cvc5Timeout,
+    #[error("solver timed out when solving problem")]
+    SolverTimeout,
 
     #[error(
-        "cvc5 returned non-zero exit code: {}",
+        "solver returned non-zero exit code: {}",
         if let Some(i) = .0 { format!("{}", i) } else { "none".to_owned() }
     )]
-    Cvc5NonZeroExitCode(Option<i32>),
+    NonZeroExitCode(Option<i32>),
 
     #[error("error in inner proof: {0}")]
     InnerProofError(Box<crate::Error>),
 
@@ -282,6 +282,12 @@ pub enum SubproofError {
     #[error("discharge must be 'assume' command: '{0}'")]
     DischargeMustBeAssume(String),
 
+    #[error("local assumption '{0}' was not discharged")]
+    LocalAssumeNotDischarged(String),
+
+    #[error("only the `subproof` rule may discharge local assumptions")]
+    DischargeInWrongRule,
+
     #[error("binding '{0}' appears as free variable in phi")]
     BindBindingIsFreeVarInPhi(String),
 
@@ -330,7 +336,7 @@ impl<'a> fmt::Display for DisplayLinearComb<'a> {
             1 => write_var(f, vars.iter().next().unwrap()),
             _ => {
                 write!(f, "(+")?;
-                for var in vars.iter() {
+                for var in vars {
                     write!(f, " ")?;
                     write_var(f, var)?;
                 }
diff --git a/carcara/src/checker/lia_generic.rs b/carcara/src/checker/lia_generic.rs
index b71e24a7..21233505 100644
--- a/carcara/src/checker/lia_generic.rs
+++ b/carcara/src/checker/lia_generic.rs
@@ -1,6 +1,6 @@
 use super::*;
-use crate::{checker::error::LiaGenericError, parser};
-use ahash::AHashMap;
+use crate::{checker::error::LiaGenericError, parser, LiaGenericOptions};
+use indexmap::IndexMap;
 use std::{
     io::{BufRead, Write},
     process::{Command, Stdio},
@@ -24,18 +24,19 @@ fn get_problem_string(conclusion: &[Rc<Term>], prelude: &ProblemPrelude) -> Stri
     problem
 }
 
-pub fn lia_generic(
-    pool: &mut TermPool,
+pub fn lia_generic_single_thread(
+    pool: &mut PrimitivePool,
     conclusion: &[Rc<Term>],
     prelude: &ProblemPrelude,
     elaborator: Option<&mut Elaborator>,
     root_id: &str,
+    options: &LiaGenericOptions,
 ) -> bool {
     let problem = get_problem_string(conclusion, prelude);
-    let commands = match get_cvc5_proof(pool, problem) {
+    let commands = match get_solver_proof(pool, problem, options) {
         Ok(c) => c,
         Err(e) => {
-            log::warn!("failed to check `lia_generic` step using cvc5: {}", e);
+            log::warn!("failed to check `lia_generic` step: {}", e);
             if let Some(elaborator) = elaborator {
                 elaborator.unchanged(conclusion);
             }
@@ -44,46 +45,57 @@ pub fn lia_generic(
     };
 
     if let Some(elaborator) = elaborator {
-        insert_cvc5_proof(pool, elaborator, commands, conclusion, root_id);
+        insert_solver_proof(pool, elaborator, commands, conclusion, root_id);
     }
     false
 }
 
-fn get_cvc5_proof(
-    pool: &mut TermPool,
+pub fn lia_generic_multi_thread(
+    conclusion: &[Rc<Term>],
+    prelude: &ProblemPrelude,
+    options: &LiaGenericOptions,
+) -> bool {
+    let mut pool = PrimitivePool::new();
+    let problem = get_problem_string(conclusion, prelude);
+    if let Err(e) = get_solver_proof(&mut pool, problem, options) {
+        log::warn!("failed to check `lia_generic` step: {}", e);
+        true
+    } else {
+        false
+    }
+}
+
+fn get_solver_proof(
+    pool: &mut PrimitivePool,
     problem: String,
+    options: &LiaGenericOptions,
 ) -> Result<Vec<ProofCommand>, LiaGenericError> {
-    let mut cvc5 = Command::new("cvc5")
-        .args([
-            "--tlimit=10000",
-            "--lang=smt2",
-            "--proof-format-mode=alethe",
-            "--proof-granularity=theory-rewrite",
-            "--proof-alethe-res-pivots",
-        ])
+    let mut process = Command::new(options.solver.as_ref())
+        .args(options.arguments.iter().map(AsRef::as_ref))
         .stdin(Stdio::piped())
         .stdout(Stdio::piped())
         .stderr(Stdio::piped())
         .spawn()
-        .map_err(LiaGenericError::FailedSpawnCvc5)?;
+        .map_err(LiaGenericError::FailedSpawnSolver)?;
 
-    cvc5.stdin
+    process
+        .stdin
         .take()
-        .expect("failed to open cvc5 stdin")
+        .expect("failed to open solver stdin")
         .write_all(problem.as_bytes())
-        .map_err(LiaGenericError::FailedWriteToCvc5Stdin)?;
+        .map_err(LiaGenericError::FailedWriteToSolverStdin)?;
 
-    let output = cvc5
+    let output = process
        .wait_with_output()
-        .map_err(LiaGenericError::FailedWaitForCvc5)?;
+        .map_err(LiaGenericError::FailedWaitForSolver)?;
 
     if !output.status.success() {
         if let Ok(s) = std::str::from_utf8(&output.stderr) {
-            if s.contains("cvc5 interrupted by timeout.") {
-                return Err(LiaGenericError::Cvc5Timeout);
+            if s.contains("interrupted by timeout.") {
+                return Err(LiaGenericError::SolverTimeout);
             }
         }
-        return Err(LiaGenericError::Cvc5NonZeroExitCode(output.status.code()));
+        return Err(LiaGenericError::NonZeroExitCode(output.status.code()));
     }
 
     let mut proof = output.stdout.as_slice();
@@ -91,28 +103,28 @@ fn get_cvc5_proof(
     proof
         .read_line(&mut first_line)
-        .map_err(|_| LiaGenericError::Cvc5GaveInvalidOutput)?;
+        .map_err(|_| LiaGenericError::SolverGaveInvalidOutput)?;
 
     if first_line.trim_end() != "unsat" {
-        return Err(LiaGenericError::Cvc5OutputNotUnsat);
+        return Err(LiaGenericError::OutputNotUnsat);
     }
 
-    parse_and_check_cvc5_proof(pool, problem.as_bytes(), proof)
+    parse_and_check_solver_proof(pool, problem.as_bytes(), proof)
         .map_err(|e| LiaGenericError::InnerProofError(Box::new(e)))
}
 
-fn parse_and_check_cvc5_proof(
-    pool: &mut TermPool,
+fn parse_and_check_solver_proof(
+    pool: &mut PrimitivePool,
     problem: &[u8],
     proof: &[u8],
 ) -> CarcaraResult<Vec<ProofCommand>> {
-    let mut parser = parser::Parser::new(pool, problem, true, false, true)?;
+    let mut parser = parser::Parser::new(pool, parser::Config::new(), problem)?;
     let (prelude, premises) = parser.parse_problem()?;
 
     parser.reset(proof)?;
     let commands = parser.parse_proof()?;
     let proof = Proof { premises, commands };
 
-    ProofChecker::new(pool, Config::new(), prelude).check(&proof)?;
+    ProofChecker::new(pool, Config::new(), &prelude).check(&proof)?;
     Ok(proof.commands)
 }
 
@@ -139,13 +151,13 @@ fn update_premises(commands: &mut [ProofCommand], delta: usize, root_id: &str) {
 }
 
 fn insert_missing_assumes(
-    pool: &mut TermPool,
+    pool: &mut PrimitivePool,
     elaborator: &mut Elaborator,
     conclusion: &[Rc<Term>],
     proof: &[ProofCommand],
     root_id: &str,
 ) -> (Vec<Rc<Term>>, usize) {
-    let mut count_map: AHashMap<&Rc<Term>, usize> = AHashMap::new();
+    let mut count_map: IndexMap<&Rc<Term>, usize> = IndexMap::new();
     for c in conclusion {
         *count_map.entry(c).or_default() += 1;
     }
@@ -179,8 +191,8 @@ fn insert_missing_assumes(
     (all, num_added)
 }
 
-fn insert_cvc5_proof(
-    pool: &mut TermPool,
+fn insert_solver_proof(
+    pool: &mut PrimitivePool,
     elaborator: &mut Elaborator,
     mut commands: Vec<ProofCommand>,
     conclusion: &[Rc<Term>],
@@ -195,7 +207,7 @@ fn insert_cvc5_proof(
         conclusion,
         &commands,
         // This is a bit ugly, but we have to add the ".added" to avoid colliding with the first few
-        // steps in the cvc5 proof
+        // steps in the solver proof
         &format!("{}.added", root_id),
     );
diff --git a/carcara/src/checker/mod.rs b/carcara/src/checker/mod.rs
index 18f9f3d1..4f1b6aac 100644
--- a/carcara/src/checker/mod.rs
+++ b/carcara/src/checker/mod.rs
@@ -1,56 +1,58 @@
-mod context;
-mod elaboration;
 pub mod error;
 mod lia_generic;
+mod parallel;
 mod rules;
 
-use crate::{ast::*, benchmarking::CollectResults, CarcaraResult, Error};
-use ahash::AHashSet;
-use context::*;
-use elaboration::Elaborator;
-use error::CheckerError;
+use crate::{
+    ast::*,
+    benchmarking::{CollectResults, OnlineBenchmarkResults},
+    elaborator::Elaborator,
+    CarcaraResult, Error, LiaGenericOptions,
+};
+use error::{CheckerError, SubproofError};
+use indexmap::IndexSet;
+pub use parallel::{scheduler::Scheduler, ParallelProofChecker};
 use rules::{ElaborationRule, Premise, Rule, RuleArgs, RuleResult};
 use std::{
     fmt,
     time::{Duration, Instant},
 };
 
-pub struct CheckerStatistics<'s> {
+#[derive(Clone)]
+pub struct CheckerStatistics<'s, CR: CollectResults + Send + Default> {
     pub file_name: &'s str,
-    pub elaboration_time: &'s mut Duration,
-    pub deep_eq_time: &'s mut Duration,
-    pub assume_time: &'s mut Duration,
+    pub elaboration_time: Duration,
+    pub polyeq_time: Duration,
+    pub assume_time: Duration,
 
     // This is the time to compare the `assume` term with the `assert` that matches it. That is,
     // this excludes the time spent searching for the correct `assert` premise.
-    pub assume_core_time: &'s mut Duration,
-    pub results: &'s mut dyn CollectResults,
+    pub assume_core_time: Duration,
+    pub results: CR,
 }
 
-impl fmt::Debug for CheckerStatistics<'_> {
+impl<CR: CollectResults + Send + Default> fmt::Debug for CheckerStatistics<'_, CR> {
     // Since `self.results` does not implement `Debug`, we can't just `#[derive(Debug)]` and instead
     // have to implement it manually, removing that field.
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("CheckerStatistics")
             .field("file_name", &self.file_name)
             .field("elaboration_time", &self.elaboration_time)
-            .field("deep_eq_time", &self.deep_eq_time)
+            .field("polyeq_time", &self.polyeq_time)
             .field("assume_time", &self.assume_time)
             .field("assume_core_time", &self.assume_core_time)
             .finish()
     }
 }
 
-#[derive(Debug, Default)]
-pub struct Config<'c> {
+#[derive(Debug, Default, Clone)]
+pub struct Config {
     strict: bool,
-    skip_unknown_rules: bool,
-    is_running_test: bool,
-    statistics: Option<CheckerStatistics<'c>>,
-    lia_via_cvc5: bool,
+    ignore_unknown_rules: bool,
+    lia_options: Option<LiaGenericOptions>,
 }
 
-impl<'c> Config<'c> {
+impl Config {
     pub fn new() -> Self {
         Self::default()
     }
@@ -60,26 +62,21 @@ impl<'c> Config<'c> {
         self
     }
 
-    pub fn skip_unknown_rules(mut self, value: bool) -> Self {
-        self.skip_unknown_rules = value;
+    pub fn ignore_unknown_rules(mut self, value: bool) -> Self {
+        self.ignore_unknown_rules = value;
         self
     }
 
-    pub fn lia_via_cvc5(mut self, value: bool) -> Self {
-        self.lia_via_cvc5 = value;
-        self
-    }
-
-    pub fn statistics(mut self, value: CheckerStatistics<'c>) -> Self {
-        self.statistics = Some(value);
+    pub fn lia_options(mut self, value: impl Into<Option<LiaGenericOptions>>) -> Self {
+        self.lia_options = value.into();
         self
     }
 }
 
 pub struct ProofChecker<'c> {
-    pool: &'c mut TermPool,
-    config: Config<'c>,
-    prelude: ProblemPrelude,
+    pool: &'c mut PrimitivePool,
+    config: Config,
+    prelude: &'c ProblemPrelude,
     context: ContextStack,
     elaborator: Option<Elaborator>,
     reached_empty_clause: bool,
@@ -87,7 +84,7 @@ pub struct ProofChecker<'c> {
 }
 
 impl<'c> ProofChecker<'c> {
-    pub fn new(pool: &'c mut TermPool, config: Config<'c>, prelude: ProblemPrelude) -> Self {
+    pub fn new(pool: &'c mut PrimitivePool, config: Config, prelude: &'c ProblemPrelude) -> Self {
         ProofChecker {
             pool,
             config,
@@ -100,6 +97,25 @@ impl<'c> ProofChecker<'c> {
     }
 
     pub fn check(&mut self, proof: &Proof) -> CarcaraResult<bool> {
+        self.check_impl(
+            proof,
+            None::<&mut CheckerStatistics<OnlineBenchmarkResults>>,
+        )
+    }
+
+    pub fn check_with_stats<CR: CollectResults + Send + Default>(
+        &mut self,
+        proof: &Proof,
+        stats: &mut CheckerStatistics<CR>,
+    ) -> CarcaraResult<bool> {
+        self.check_impl(proof, Some(stats))
+    }
+
+    fn check_impl<CR: CollectResults + Send + Default>(
+        &mut self,
+        proof: &Proof,
+        mut stats: Option<&mut CheckerStatistics<CR>>,
+    ) -> CarcaraResult<bool> {
         // Similarly to the parser, to avoid stack overflows in proofs with many nested subproofs,
         // we check the subproofs iteratively, instead of recursively
         let mut iter = proof.iter();
@@ -119,7 +135,7 @@ impl<'c> ProofChecker<'c> {
                     } else {
                         None
                     };
-                    self.check_step(step, previous_command, &iter)
+                    self.check_step(step, previous_command, &iter, &mut stats)
                         .map_err(|e| Error::Checker {
                             inner: e,
                             rule: step.rule.clone(),
@@ -144,8 +160,14 @@ impl<'c> ProofChecker<'c> {
                     let time = Instant::now();
                     let step_id = command.id();
 
+                    let new_context_id = self.context.force_new_context();
                     self.context
-                        .push(self.pool, &s.assignment_args, &s.variable_args)
+                        .push(
+                            self.pool,
+                            &s.assignment_args,
+                            &s.variable_args,
+                            new_context_id,
+                        )
                         .map_err(|e| Error::Checker {
                             inner: e.into(),
                             rule: "anchor".into(),
@@ -156,7 +178,7 @@ impl<'c> ProofChecker<'c> {
                         elaborator.open_subproof(s.commands.len());
                     }
 
-                    if let Some(stats) = &mut self.config.statistics {
+                    if let Some(stats) = &mut stats {
                         let rule_name = match s.commands.last() {
                             Some(ProofCommand::Step(step)) => format!("anchor({})", &step.rule),
                             _ => "anchor".to_owned(),
@@ -170,7 +192,7 @@ impl<'c> ProofChecker<'c> {
                    }
                 }
                 ProofCommand::Assume { id, term } => {
-                    if !self.check_assume(id, term, &proof.premises, &iter) {
+                    if !self.check_assume(id, term, &proof.premises, &iter, &mut stats) {
                         return Err(Error::Checker {
                             inner: CheckerError::Assume(term.clone()),
                             rule: "assume".into(),
@@ -180,7 +202,7 @@ impl<'c> ProofChecker<'c> {
                }
            }
        }
-        if self.config.is_running_test || self.reached_empty_clause {
+        if self.reached_empty_clause {
             Ok(self.is_holey)
         } else {
             Err(Error::DoesNotReachEmptyClause)
        }
@@ -191,6 +213,24 @@
         self.elaborator = Some(Elaborator::new());
         let result = self.check(&proof);
 
+        // We reset `self.elaborator` before returning any errors encountered while checking so
+        // we don't leave the checker in an invalid state
+        let mut elaborator = self.elaborator.take().unwrap();
+        result?;
+
+        proof.commands = elaborator.end(proof.commands);
+
+        Ok((self.is_holey, proof))
+    }
+
+    pub fn check_and_elaborate_with_stats<'s, CR: CollectResults + Send + Default>(
+        &'s mut self,
+        mut proof: Proof,
+        stats: &'s mut CheckerStatistics<CR>,
+    ) -> CarcaraResult<(bool, Proof)> {
+        self.elaborator = Some(Elaborator::new());
+        let result = self.check_with_stats(&proof, stats);
+
+        // We reset `self.elaborator` before returning any errors encountered while checking so we
         // don't leave the checker in an invalid state
         let mut elaborator = self.elaborator.take().unwrap();
@@ -198,27 +238,25 @@
         let elaboration_time = Instant::now();
         proof.commands = elaborator.end(proof.commands);
 
-        if let Some(stats) = &mut self.config.statistics {
-            *stats.elaboration_time += elaboration_time.elapsed();
-        }
+        stats.elaboration_time += elaboration_time.elapsed();
+
         Ok((self.is_holey, proof))
     }
 
-    fn check_assume(
+    fn check_assume<'i, CR: CollectResults + Send + Default>(
         &mut self,
         id: &str,
         term: &Rc<Term>,
-        premises: &AHashSet<Rc<Term>>,
-        iter: &ProofIter,
+        premises: &IndexSet<Rc<Term>>,
+        iter: &'i ProofIter<'i>,
+        mut stats: &mut Option<&mut CheckerStatistics<CR>>,
     ) -> bool {
         let time = Instant::now();
 
-        // Some subproofs contain `assume` commands inside them. These don't refer
-        // to the original problem premises, so we ignore the `assume` command if
-        // it is inside a subproof. Since the unit tests for the rules don't define the
-        // original problem, but sometimes use `assume` commands, we also skip the
-        // command if we are in a testing context.
-        if self.config.is_running_test || iter.is_in_subproof() {
+        // Some subproofs contain `assume` commands inside them. These don't refer to the original
+        // problem premises, but are instead local assumptions that are discharged by the subproof's
+        // final step, so we ignore the `assume` command if it is inside a subproof.
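+        // (A hypothetical example: a subproof proving `(=> p q)` would start with a local
+        // `(assume t1.a0 p)`, and its final `subproof` step would list `t1.a0` in its
+        // `:discharge` attribute; the `check_discharge` method added below verifies that every
+        // local `assume` is discharged this way.)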
+        if iter.is_in_subproof() {
             if let Some(elaborator) = &mut self.elaborator {
                 elaborator.assume(term);
             }
@@ -226,9 +264,10 @@
         }
 
         if premises.contains(term) {
-            if let Some(s) = &mut self.config.statistics {
+            if let Some(s) = stats {
                 let time = time.elapsed();
-                *s.assume_time += time;
+
+                s.assume_time += time;
                 s.results
                     .add_assume_measurement(s.file_name, id, true, time);
             }
@@ -243,18 +282,18 @@
         }
 
         let mut found = None;
-        let mut deep_eq_time = Duration::ZERO;
+        let mut polyeq_time = Duration::ZERO;
         let mut core_time = Duration::ZERO;
 
         for p in premises {
-            let mut this_deep_eq_time = Duration::ZERO;
-            let (result, depth) = tracing_deep_eq(term, p, &mut this_deep_eq_time);
-            deep_eq_time += this_deep_eq_time;
-            if let Some(s) = &mut self.config.statistics {
-                s.results.add_deep_eq_depth(depth);
+            let mut this_polyeq_time = Duration::ZERO;
+            let (result, depth) = tracing_polyeq(term, p, &mut this_polyeq_time);
+            polyeq_time += this_polyeq_time;
+            if let Some(s) = &mut stats {
+                s.results.add_polyeq_depth(depth);
             }
             if result {
-                core_time = this_deep_eq_time;
+                core_time = this_polyeq_time;
                 found = Some(p.clone());
                 break;
             }
         }
@@ -267,16 +306,17 @@
             elaborator.elaborate_assume(self.pool, p, term.clone(), id);
 
-            if let Some(s) = &mut self.config.statistics {
-                *s.elaboration_time += elaboration_time.elapsed();
+            if let Some(s) = &mut stats {
+                s.elaboration_time += elaboration_time.elapsed();
             }
         }
 
-        if let Some(s) = &mut self.config.statistics {
+        if let Some(s) = &mut stats {
             let time = time.elapsed();
-            *s.assume_time += time;
-            *s.assume_core_time += core_time;
-            *s.deep_eq_time += deep_eq_time;
+
+            s.assume_time += time;
+            s.assume_core_time += core_time;
+            s.polyeq_time += polyeq_time;
             s.results
                 .add_assume_measurement(s.file_name, id, false, time);
         }
@@ -284,24 +324,30 @@
         true
     }
 
-    fn check_step<'a>(
+    fn check_step<'i, CR: CollectResults + Send + Default>(
         &mut self,
-        step: &'a ProofStep,
-        previous_command: Option<Premise<'a>>,
-        iter: &'a ProofIter<'a>,
+        step: &ProofStep,
+        previous_command: Option<Premise>,
+        iter: &'i ProofIter<'i>,
+        stats: &mut Option<&mut CheckerStatistics<CR>>,
     ) -> RuleResult {
         let time = Instant::now();
-        let mut deep_eq_time = Duration::ZERO;
+        let mut polyeq_time = Duration::ZERO;
+
+        if !step.discharge.is_empty() && step.rule != "subproof" {
+            return Err(CheckerError::Subproof(SubproofError::DischargeInWrongRule));
+        }
 
         let mut elaborated = false;
         if step.rule == "lia_generic" {
-            if self.config.lia_via_cvc5 {
-                let is_hole = lia_generic::lia_generic(
+            if let Some(options) = &self.config.lia_options {
+                let is_hole = lia_generic::lia_generic_single_thread(
                     self.pool,
                     &step.clause,
-                    &self.prelude,
+                    self.prelude,
                     self.elaborator.as_mut(),
                     &step.id,
+                    options,
                 );
                 self.is_holey = self.is_holey || is_hole;
                 elaborated = self.elaborator.is_some();
@@ -315,7 +361,7 @@
         } else {
             let rule = match Self::get_rule(&step.rule, self.config.strict) {
                 Some(r) => r,
-                None if self.config.skip_unknown_rules => {
+                None if self.config.ignore_unknown_rules => {
                     self.is_holey = true;
                     if let Some(elaborator) = &mut self.elaborator {
                         elaborator.unchanged(&step.clause);
@@ -351,7 +397,7 @@
                 context: &mut self.context,
                 previous_command,
                 discharge: &discharge,
-                deep_eq_time: &mut deep_eq_time,
+                polyeq_time: &mut polyeq_time,
             };
 
             if let Some(elaborator) = &mut self.elaborator {
@@ -367,18 +413,43 @@
            }
         }
 
-        if let Some(s) = &mut self.config.statistics {
+        if iter.is_end_step() {
+            let subproof = iter.current_subproof().unwrap();
+            Self::check_discharge(subproof, iter.depth(), &step.discharge)?;
+        }
+
+        if let Some(s) = stats {
             let time = time.elapsed();
+
             s.results
                 .add_step_measurement(s.file_name, &step.id, &step.rule, time);
-            *s.deep_eq_time += deep_eq_time;
+            s.polyeq_time += polyeq_time;
             if elaborated {
-                *s.elaboration_time += time;
+                s.elaboration_time += time;
             }
         }
         Ok(())
     }
 
+    fn check_discharge(
+        subproof: &[ProofCommand],
+        depth: usize,
+        discharge: &[(usize, usize)],
+    ) -> RuleResult {
+        let discharge: IndexSet<_> = discharge.iter().collect();
+        if let Some((_, not_discharged)) = subproof
+            .iter()
+            .enumerate()
+            .find(|&(i, command)| command.is_assume() && !discharge.contains(&(depth, i)))
+        {
+            Err(CheckerError::Subproof(
+                SubproofError::LocalAssumeNotDischarged(not_discharged.id().to_owned()),
+            ))
+        } else {
+            Ok(())
+        }
+    }
+
     pub fn get_rule(rule_name: &str, strict: bool) -> Option<Rule> {
         use rules::*;
diff --git a/carcara/src/checker/parallel/mod.rs b/carcara/src/checker/parallel/mod.rs
new file mode 100644
index 00000000..c5eac4c6
--- /dev/null
+++ b/carcara/src/checker/parallel/mod.rs
@@ -0,0 +1,492 @@
+pub mod scheduler;
+
+use super::{
+    error::{CheckerError, SubproofError},
+    lia_generic,
+    rules::{Premise, RuleArgs, RuleResult},
+    Config, ProofChecker,
+};
+use crate::benchmarking::{CollectResults, OnlineBenchmarkResults};
+use crate::checker::CheckerStatistics;
+use crate::{
+    ast::{pool::advanced::*, *},
+    CarcaraResult, Error,
+};
+use indexmap::IndexSet;
+pub use scheduler::{Schedule, ScheduleIter, Scheduler};
+use std::{
+    ops::ControlFlow,
+    sync::{atomic::AtomicBool, Arc},
+    thread,
+    time::{Duration, Instant},
+};
+
+pub struct ParallelProofChecker<'c> {
+    pool: Arc<PrimitivePool>,
+    config: Config,
+    prelude: &'c ProblemPrelude,
+    context: ContextStack,
+    reached_empty_clause: bool,
+    is_holey: bool,
+    stack_size: usize,
+}
+
+impl<'c> ParallelProofChecker<'c> {
+    pub fn new(
+        pool: Arc<PrimitivePool>,
+        config: Config,
+        prelude: &'c ProblemPrelude,
+        context_usage: &Vec<usize>,
+        stack_size: usize,
+    ) -> Self {
+        ParallelProofChecker {
+            pool,
+            config,
+            prelude,
+            context: ContextStack::from_usage(context_usage),
+            reached_empty_clause: false,
+            is_holey: false,
+            stack_size,
+        }
+    }
+
+    /// Copies the proof checker and instantiates the parallel fields to be shared between threads
+    pub fn share(&self) -> Self {
+        ParallelProofChecker {
+            pool: self.pool.clone(),
+            config: self.config.clone(),
+            prelude: self.prelude,
+            context: ContextStack::from_previous(&self.context),
+            reached_empty_clause: false,
+            is_holey: false,
+            stack_size: self.stack_size,
+        }
+    }
+
+    pub fn check(&mut self, proof: &Proof, scheduler: &Scheduler) -> CarcaraResult<bool> {
+        // Used to signal the threads to abort prematurely (this only happens when some thread
+        // has already found an invalid step)
+        let premature_abort = Arc::new(AtomicBool::new(false));
+        let context_pool = ContextPool::from_global(&self.pool);
+        //
+        thread::scope(|s| {
+            let threads: Vec<_> = scheduler
+                .loads
+                .iter()
+                .enumerate()
+                .map(|(i, schedule)| {
+                    // Share `self` between the threads
+                    let mut local_self = self.share();
+                    let local_pool = LocalPool::from_previous(&context_pool);
+                    let should_abort = premature_abort.clone();
+
+                    thread::Builder::new()
+                        .name(format!("worker-{i}"))
+                        .stack_size(self.stack_size)
+                        .spawn_scoped(s, move || -> CarcaraResult<(bool, bool)> {
+                            local_self.worker_thread_check(
+                                proof,
+                                schedule,
+                                local_pool,
+                                should_abort,
+                                None::<&mut CheckerStatistics<OnlineBenchmarkResults>>,
+                            )
+                        })
+                        .unwrap()
+                })
+                .collect();
+
+            // Unify the results of all threads and generate the final result based on them
+            let (mut reached, mut holey) = (false, false);
+            let mut err: Result<_, Error> = Ok(());
+
+            // Wait until the threads finish and merge the results and statistics
+            threads
+                .into_iter()
+                .map(|t| t.join().unwrap())
+                .try_for_each(|opt| {
+                    match opt {
+                        Ok((local_reached, local_holey)) => {
+                            // Combine the result booleans
+                            (reached, holey) = (reached | local_reached, holey | local_holey);
+                            ControlFlow::Continue(())
+                        }
+                        Err(e) => {
+                            err = Err(e);
+                            ControlFlow::Break(())
+                        }
+                    }
+                });
+
+            // If an error happened, propagate it
+            err?;
+
+            if reached {
+                Ok(holey)
+            } else {
+                Err(Error::DoesNotReachEmptyClause)
+            }
+        })
+    }
+
+    pub fn check_with_stats<CR: CollectResults + Send + Default>(
+        &mut self,
+        proof: &Proof,
+        scheduler: &Scheduler,
+        stats: &mut CheckerStatistics<CR>,
+    ) -> CarcaraResult<bool> {
+        // Used to signal the threads to abort prematurely (this only happens when some thread
+        // has already found an invalid step)
+        let premature_abort = Arc::new(AtomicBool::new(false));
+        let context_pool = ContextPool::from_global(&self.pool);
+        //
+        thread::scope(|s| {
+            let threads: Vec<_> = scheduler
+                .loads
+                .iter()
+                .enumerate()
+                .map(|(i, schedule)| {
+                    let mut local_stats = CheckerStatistics {
+                        file_name: "",
+                        elaboration_time: Duration::ZERO,
+                        polyeq_time: Duration::ZERO,
+                        assume_time: Duration::ZERO,
+                        assume_core_time: Duration::ZERO,
+                        results: CR::default(),
+                    };
+                    // Share the proof checker between the threads
+                    let mut local_self = self.share();
+                    let local_pool = LocalPool::from_previous(&context_pool);
+                    let should_abort = premature_abort.clone();
+
+                    thread::Builder::new()
+                        .name(format!("worker-{i}"))
+                        .stack_size(self.stack_size)
+                        .spawn_scoped(
+                            s,
+                            move || -> CarcaraResult<(bool, bool, CheckerStatistics<CR>)> {
+                                local_self
+                                    .worker_thread_check(
+                                        proof,
+                                        schedule,
+                                        local_pool,
+                                        should_abort,
+                                        Some(&mut local_stats),
+                                    )
+                                    .map(|r| (r.0, r.1, local_stats))
+                            },
+                        )
+                        .unwrap()
+                })
+                .collect();
+
+            // Unify the results of all threads and generate the final result based on them
+            let (mut reached, mut holey) = (false, false);
+            let mut err: Result<_, Error> = Ok(());
+
+            // Wait until the threads finish and merge the results and statistics
+            threads
+                .into_iter()
+                .map(|t| t.join().unwrap())
+                .for_each(|opt| {
+                    match opt {
+                        Ok((local_reached, local_holey, mut local_stats)) => {
+                            // Combine the statistics. Take the external and local benchmark
+                            // results into local variables and combine them
+                            let main = std::mem::take(&mut stats.results);
+                            let to_merge = std::mem::take(&mut local_stats.results);
+                            stats.results = CR::combine(main, to_merge);
+
+                            // Make sure the other times are updated
+                            stats.elaboration_time += local_stats.elaboration_time;
+                            stats.polyeq_time += local_stats.polyeq_time;
+                            stats.assume_time += local_stats.assume_time;
+                            stats.assume_core_time += local_stats.assume_core_time;
+
+                            // Combine the result booleans
+                            (reached, holey) = (reached | local_reached, holey | local_holey);
+                        }
+                        Err(e) => {
+                            // Since we want the statistics of the whole run (even in an error
+                            // case), we cannot abort at this point: there may be more threads
+                            // to be evaluated and their statistics collected
+                            err = Err(e);
+                        }
+                    }
+                });
+
+            // If an error happened, propagate it
+            err?;
+
+            if reached {
+                Ok(holey)
+            } else {
+                Err(Error::DoesNotReachEmptyClause)
+            }
+        })
+    }
+
+    fn worker_thread_check<CR: CollectResults + Send + Default>(
+        &mut self,
+        proof: &Proof,
+        schedule: &Schedule,
+        mut pool: LocalPool,
+        should_abort: Arc<AtomicBool>,
+        mut stats: Option<&mut CheckerStatistics<CR>>,
+    ) -> CarcaraResult<(bool, bool)> {
+        use std::sync::atomic::Ordering;
+
+        let mut iter = schedule.iter(&proof.commands[..]);
+        let mut last_depth = 0;
+
+        while let Some(command) = iter.next() {
+            // If there is any depth difference between the current and last step
+            while (last_depth - iter.depth() as i64 > 0)
+                || (last_depth - iter.depth() as i64 == 0
+                    && matches!(command, ProofCommand::Subproof(_)))
+            {
+                // If this is the last command of a subproof, we have to pop the subproof
+                // commands off the stack. The parser already ensures that the last command
+                // in a subproof is always a `step` command
+                self.context.pop();
+                last_depth -= 1;
+            }
+            last_depth = iter.depth() as i64;
+
+            match command {
+                ProofCommand::Step(step) => {
+                    // If this step ends a subproof, it might need to implicitly reference the
+                    // previous command in the subproof
+                    let previous_command = if iter.is_end_step() {
+                        let subproof = iter.current_subproof().unwrap();
+                        let index = subproof.len() - 2;
+                        subproof
+                            .get(index)
+                            .map(|command| Premise::new((iter.depth(), index), command))
+                    } else {
+                        None
+                    };
+
+                    self.check_step(step, previous_command, &iter, &mut pool, &mut stats)
+                        .map_err(|e| {
+                            // Signal the other threads to stop the proof checking
+                            should_abort.store(true, Ordering::Release);
+                            Error::Checker {
+                                inner: e,
+                                rule: step.rule.clone(),
+                                step: step.id.clone(),
+                            }
+                        })?;
+
+                    if step.clause.is_empty() {
+                        self.reached_empty_clause = true;
+                    }
+                }
+                ProofCommand::Subproof(s) => {
+                    let time = Instant::now();
+                    let step_id = command.id();
+
+                    self.context
+                        .push(
+                            &mut pool.ctx_pool,
+                            &s.assignment_args,
+                            &s.variable_args,
+                            s.context_id,
+                        )
+                        .map_err(|e| {
+                            // Signal the other threads to stop the proof checking
+                            should_abort.store(true, Ordering::Release);
+                            Error::Checker {
+                                inner: e.into(),
+                                rule: "anchor".into(),
+                                step: step_id.to_owned(),
+                            }
+                        })?;
+
+                    if let Some(stats) = &mut stats {
+                        // Collect statistics
+                        let rule_name = match s.commands.last() {
+                            Some(ProofCommand::Step(step)) => {
+                                format!("anchor({})", &step.rule)
+                            }
+                            _ => "anchor".to_owned(),
+                        };
+                        stats.results.add_step_measurement(
+                            stats.file_name,
+                            step_id,
+                            &rule_name,
+                            time.elapsed(),
+                        );
+                    }
+                }
+                ProofCommand::Assume { id, term } => {
+                    if !self.check_assume(id, term, &proof.premises, &iter, &mut stats) {
+                        // Signal the other threads to stop the proof checking
+                        should_abort.store(true, Ordering::Release);
+                        return Err(Error::Checker {
+                            inner: CheckerError::Assume(term.clone()),
+                            rule: "assume".into(),
+                            step: id.clone(),
+                        });
+                    }
+                }
+            }
+            // Check whether any of the other threads found an error, and abort if so
+            if should_abort.load(Ordering::Acquire) {
+                break;
+            }
+        }
+
+        // Return Ok((reached the empty clause, is holey))
+        if self.reached_empty_clause {
+            Ok((true, self.is_holey))
+        } else {
+            Ok((false, self.is_holey))
+        }
+    }
+
+    fn check_assume<CR: CollectResults + Send + Default>(
+        &mut self,
+        id: &str,
+        term: &Rc<Term>,
+        premises: &IndexSet<Rc<Term>>,
+        iter: &ScheduleIter,
+        mut stats: &mut Option<&mut CheckerStatistics<CR>>,
+    ) -> bool {
+        let time = Instant::now();
+
+        // Similarly to the single-threaded checker, we ignore `assume` commands that are inside
+        // subproofs
+        if iter.is_in_subproof() {
+            return true;
+        }
+
+        if premises.contains(term) {
+            if let Some(s) = stats {
+                let time = time.elapsed();
+                s.assume_time += time;
+                s.results
+                    .add_assume_measurement(s.file_name, id, true, time);
+            }
+            return true;
+        }
+
+        if self.config.strict {
+            return false;
+        }
+
+        let mut found = None;
+        let mut polyeq_time = Duration::ZERO;
+        let mut core_time = Duration::ZERO;
+
+        for p in premises {
+            let mut this_polyeq_time = Duration::ZERO;
+            let (result, depth) = tracing_polyeq(term, p, &mut this_polyeq_time);
+            polyeq_time += this_polyeq_time;
+            if let Some(s) = &mut stats {
+                s.results.add_polyeq_depth(depth);
+            }
+            if result {
+                core_time = this_polyeq_time;
+                found = Some(p.clone());
+                break;
+            }
+        }
+
+        if found.is_none() {
+            return false;
+        }
+
+        if let Some(s) = stats {
+            let time = time.elapsed();
+            s.assume_time += time;
+            s.assume_core_time += core_time;
+            s.polyeq_time += polyeq_time;
+            s.results
+                .add_assume_measurement(s.file_name, id, false, time);
+        }
+
+        true
+    }
+
+    fn check_step<CR: CollectResults + Send + Default>(
+        &mut self,
+        step: &ProofStep,
+        previous_command: Option<Premise>,
+        iter: &ScheduleIter,
+        pool: &mut LocalPool,
+        stats: &mut Option<&mut CheckerStatistics<CR>>,
+    ) -> RuleResult {
+        let time = Instant::now();
+        let mut polyeq_time = Duration::ZERO;
+
+        if !step.discharge.is_empty() && step.rule != "subproof" {
+            return Err(CheckerError::Subproof(SubproofError::DischargeInWrongRule));
+        }
+
+        if step.rule == "lia_generic" {
+            if let Some(options) = &self.config.lia_options {
+                let is_hole =
+                    lia_generic::lia_generic_multi_thread(&step.clause, self.prelude, options);
+                self.is_holey = self.is_holey || is_hole;
+            } else {
+                log::warn!("encountered \"lia_generic\" rule, ignoring");
+                self.is_holey = true;
+            }
+        } else {
+            let rule = match ProofChecker::get_rule(&step.rule, self.config.strict) {
+                Some(r) => r,
+                None if self.config.ignore_unknown_rules => {
+                    self.is_holey = true;
+                    return Ok(());
+                }
+                None => return Err(CheckerError::UnknownRule),
+            };
+
+            if step.rule == "hole" {
+                self.is_holey = true;
+            }
+
+            let premises: Vec<_> = step
+                .premises
+                .iter()
+                .map(|&p| {
+                    let command = iter.get_premise(p);
+                    Premise::new(p, command)
+                })
+                .collect();
+            let discharge: Vec<_> = step
+                .discharge
+                .iter()
+                .map(|&i| iter.get_premise(i))
+                .collect();
+
+            let rule_args = RuleArgs {
+                conclusion: &step.clause,
+                premises: &premises,
+                args: &step.args,
+                pool,
+                context: &mut self.context,
+                previous_command,
+                discharge: &discharge,
+                polyeq_time: &mut polyeq_time,
+            };
+
+            rule(rule_args)?;
+        }
+
+        if iter.is_end_step() {
+            let subproof = iter.current_subproof().unwrap();
+            ProofChecker::check_discharge(subproof, iter.depth(), &step.discharge)?;
+        }
+
+        if let Some(s) = stats {
+            let time = time.elapsed();
+            s.results
+                .add_step_measurement(s.file_name, &step.id, &step.rule, time);
+            s.polyeq_time += polyeq_time;
+        }
+        Ok(())
+    }
+}
diff --git a/carcara/src/checker/parallel/scheduler.rs b/carcara/src/checker/parallel/scheduler.rs
new file mode 100644
index 00000000..b5988033
--- /dev/null
+++ b/carcara/src/checker/parallel/scheduler.rs
@@ -0,0 +1,415 @@
+use crate::ast::{Proof, ProofCommand};
+use std::{
+    cmp::Ordering,
+    collections::{BinaryHeap, HashSet},
+};
+
+/// Struct responsible for storing a thread work schedule.
+///
+/// Here, each step from the original proof is represented as a tuple:
+/// (depth, subproof index). The first element is the subproof nesting `depth`
+/// (in the subproof stack) and `subproof index` is the index where this step is
+/// located in the subproof vector.
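+///
+/// For example (illustrative values only), the schedule `[(0, 0), (1, 0), (1, 1)]`
+/// makes a worker check the first root command and then the first two commands of
+/// the subproof it opens; a closing entry of the form `(depth, usize::MAX)` marks
+/// the point where that subproof's context is popped.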
+#[derive(Clone, Default)]
+pub struct Schedule {
+    steps: Vec<(usize, usize)>,
+}
+
+impl Schedule {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Inserts a new step at the end of the schedule's steps vector
+    pub fn push(&mut self, cmd: (usize, usize)) {
+        self.steps.push(cmd);
+    }
+
+    /// Removes the last step from the end of the steps vector
+    pub fn pop(&mut self) {
+        self.steps.pop();
+    }
+
+    /// Returns the last schedule step
+    pub fn last(&self) -> Option<&(usize, usize)> {
+        self.steps.last()
+    }
+
+    /// Returns an iterator over the proof commands. See [`ScheduleIter`].
+    pub fn iter<'a>(&'a self, proof: &'a [ProofCommand]) -> ScheduleIter {
+        ScheduleIter::new(proof, &self.steps)
+    }
+}
+
+// =============================================================================
+
+/// Represents the current load assigned to a specific schedule.
+/// `0`: Current work load
+/// `1`: Schedule index
+#[derive(Eq)]
+struct AssignedLoad(u64, usize);
+
+impl Ord for AssignedLoad {
+    fn cmp(&self, other: &Self) -> Ordering {
+        other.0.cmp(&self.0)
+    }
+}
+
+impl PartialOrd for AssignedLoad {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl PartialEq for AssignedLoad {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+/// Represents a level in the proof stack. It holds the subproof itself,
+/// its prerequisite step (anchor), and which schedules have used any step
+/// inside this layer
+struct StackLevel<'a> {
+    id: usize,
+    cmds: &'a [ProofCommand],
+    pre_req: Option<(usize, usize)>,
+    used_by: HashSet<usize>,
+}
+
+impl<'a> StackLevel<'a> {
+    pub fn new(id: usize, cmds: &'a [ProofCommand], pre_req: Option<(usize, usize)>) -> Self {
+        Self {
+            id,
+            cmds,
+            pre_req,
+            used_by: HashSet::new(),
+        }
+    }
+}
+
+/// Struct that stores the schedules for each thread.
+pub struct Scheduler {
+    pub loads: Vec<Schedule>,
+}
+
+impl Scheduler {
+    /// Creates a thread scheduler for this proof using a specific number of
+    /// workers. This scheduler is responsible for balancing the load (the
+    /// proof steps have different costs to be checked) while aiming for a
+    /// minimum amount of synchronization overhead.
+    ///
+    /// Returns the scheduler itself and context usage info (a vector holding
+    /// how many threads are going to use each of the contexts. This vector maps
+    /// the contexts based on the subproof hashing value (i.e. `subproof_id`)
+    /// created in the parser).
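+    ///
+    /// A minimal usage sketch (assuming a parsed `proof` is at hand):
+    ///
+    /// ```ignore
+    /// let (scheduler, context_usage) = Scheduler::new(4, &proof);
+    /// assert_eq!(scheduler.loads.len(), 4); // one schedule per worker
+    /// ```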
+    pub fn new(num_workers: usize, proof: &Proof) -> (Self, Vec<usize>) {
+        // Initialize the control and result variables
+        let cmds = &proof.commands;
+        let mut loads = vec![Schedule::new(); num_workers];
+        let mut stack = vec![StackLevel::new(0, cmds, None)];
+        let mut pq = BinaryHeap::<AssignedLoad>::new();
+        let mut context_usage = vec![];
+        for i in 0..num_workers {
+            pq.push(AssignedLoad(0, i));
+        }
+
+        loop {
+            // Pop the finished subproofs
+            while !stack.is_empty() && {
+                let top = stack.last().unwrap();
+                top.id == top.cmds.len()
+            } {
+                for schedule_id in &stack.last().unwrap().used_by {
+                    let last = loads[*schedule_id].last().unwrap();
+                    // If it's a useless context insertion
+                    if last.0 < stack.len()
+                        && matches!(stack[last.0].cmds[last.1], ProofCommand::Subproof(_))
+                    {
+                        // Make sure this context usage count is reduced
+                        let subproof_id = match &stack[last.0].cmds[last.1] {
+                            ProofCommand::Subproof(s) => s.context_id,
+                            _ => unreachable!(),
+                        };
+                        context_usage[subproof_id] -= 1;
+
+                        loads[*schedule_id].pop();
+                    }
+                    // Create a closing step for each schedule that used this subproof
+                    else {
+                        loads[*schedule_id].push((stack.len() - 1, usize::MAX));
+                    }
+                }
+                stack.pop();
+            }
+            if stack.is_empty() {
+                break;
+            }
+            //
+            let AssignedLoad(mut load, load_index) = pq.pop().unwrap();
+            {
+                let top = stack.last().unwrap();
+                let step_weight = get_step_weight(&top.cmds[top.id]);
+                load = load
+                    .checked_add(step_weight)
+                    .expect("Weight balancing overflow!");
+                pq.push(AssignedLoad(load, load_index));
+            }
+
+            let depth = stack.len() - 1;
+            let mut i = 1;
+            let initial_layer = {
+                let tmp = loads[load_index].last().unwrap_or(&(0, 0));
+                if tmp.1 == usize::MAX {
+                    tmp.0 - 1
+                } else {
+                    tmp.0
+                }
+            };
+            // If this step needs the context of the subproof opening step
+            // but it was not assigned to this schedule yet
+            while initial_layer + i <= depth {
+                let subproof_opening = stack[initial_layer + i].pre_req.unwrap();
+                let last_inserted = *loads[load_index].last().unwrap_or(&(usize::MAX, 0));
+
+                if last_inserted != subproof_opening {
+                    loads[load_index].push(subproof_opening);
+                    stack[subproof_opening.0].used_by.insert(load_index);
+
+                    // Now this subproof is used by another schedule
+                    let subproof_id = match &stack[subproof_opening.0].cmds[subproof_opening.1] {
+                        ProofCommand::Subproof(s) => s.context_id,
+                        _ => unreachable!(),
+                    };
+                    context_usage[subproof_id] += 1;
+                }
+                i += 1;
+            }
+
+            let top = stack.last_mut().unwrap();
+            // Assign a step to some schedule
+            loads[load_index].push((depth, top.id));
+            top.used_by.insert(load_index);
+
+            // Go to the next step
+            let last_id = top.id;
+            top.id += 1;
+            if let ProofCommand::Subproof(s) = &top.cmds[last_id] {
+                stack.push(StackLevel::new(0, &s.commands, Some((depth, last_id))));
+                stack.last_mut().unwrap().used_by.insert(load_index);
+                // First schedule using this subproof
+                context_usage.push(1);
+            }
+        }
+        (Scheduler { loads }, context_usage)
+    }
+}
+
+/// Iterates through schedule steps
+pub struct ScheduleIter<'a> {
+    proof_stack: Vec<&'a [ProofCommand]>,
+    steps: &'a Vec<(usize, usize)>,
+    step_id: usize,
+}
+
+impl<'a> ScheduleIter<'a> {
+    pub fn new(proof_commands: &'a [ProofCommand], steps: &'a Vec<(usize, usize)>) -> Self {
+        Self {
+            proof_stack: vec![proof_commands],
+            steps,
+            step_id: 0,
+        }
+    }
+
+    /// Returns the current nesting depth of the iterator, or more precisely,
+    /// the nesting depth of the last step that was returned. This depth starts
+    /// at zero, for steps in the root proof.
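+    /// (For example, a step nested inside a single `anchor` has depth 1; each
+    /// closing entry `(depth, usize::MAX)` pops the iterator back one level.)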
+    pub fn depth(&self) -> usize {
+        self.proof_stack.len() - 1
+    }
+
+    /// Returns `true` if the iterator is currently in a subproof, that is, if
+    /// its depth is greater than zero.
+    pub fn is_in_subproof(&self) -> bool {
+        self.depth() > 0
+    }
+
+    /// Returns a slice to the commands of the inner-most open subproof.
+    pub fn current_subproof(&self) -> Option<&[ProofCommand]> {
+        self.is_in_subproof()
+            .then(|| *self.proof_stack.last().unwrap())
+    }
+
+    /// Returns `true` if the most recently returned step is the last step of
+    /// the current subproof.
+    pub fn is_end_step(&self) -> bool {
+        self.is_in_subproof()
+            && self.steps[self.step_id - 1].1 == self.proof_stack.last().unwrap().len() - 1
+    }
+
+    /// Returns the command referenced by a premise index of the form (depth, index in subproof).
+    /// This method may panic if the premise index does not refer to a valid command.
+    pub fn get_premise(&self, (depth, index): (usize, usize)) -> &ProofCommand {
+        &self.proof_stack[depth][index]
+    }
+}
+
+impl<'a> Iterator for ScheduleIter<'a> {
+    type Item = &'a ProofCommand;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        // If we are at the end of the steps
+        if self.step_id >= self.steps.len() {
+            return None;
+        }
+
+        // If the current step is a closing subproof step
+        while let (_, usize::MAX) = self.steps[self.step_id] {
+            self.proof_stack.pop();
+            self.step_id += 1;
+            // If we reached the last closing step of the whole proof
+            if self.step_id == self.steps.len() {
+                return None;
+            }
+        }
+        let cur_step = self.steps[self.step_id];
+        self.step_id += 1;
+
+        let top = self.proof_stack.last().unwrap();
+        let command = &top[cur_step.1];
+        // Open a new subproof
+        if let ProofCommand::Subproof(subproof) = command {
+            self.proof_stack.push(&subproof.commands);
+        }
+        Some(command)
+    }
+}
+
+/// Function that returns a weight associated with a specific rule. These
+/// weights are directly correlated with the median performance of Carcara
+/// (single-threaded, previous version) while checking each of those rules.
+///
+/// Even though subproofs should have a weight (since they have a high cost to
+/// be computed), it is better for the scheduler architecture that subproofs
+/// have a null weight.
+///
+/// If you're interested in these weight values, take a look at [Carcara's
+/// paper](https://hanielbarbosa.com/papers/tacas2023.pdf)
+/// published at TACAS in April 2023 and its benchmark data.
+///
+/// The rules with null weight are rules for which we had no info about the
+/// median performance, since the solver used in the paper's dataset does not
+/// generate these rules.
+pub fn get_step_weight(step: &ProofCommand) -> u64 {
+    match step {
+        ProofCommand::Assume { .. } => 230,
+        ProofCommand::Subproof(_) => 0,
+        ProofCommand::Step(s) => {
+            match &s.rule as &str {
+                "assume" => 230,
+                "true" => 0, //-1
+                "false" => 263,
+                "not_not" => 574,
+                "and_pos" => 361,
+                "and_neg" => 607,
+                "or_pos" => 640,
+                "or_neg" => 460,
+                "xor_pos1" => 763,
+                "xor_pos2" => 345,
+                "xor_neg1" => 0, //-1
+                "xor_neg2" => 0, //-1
+                "implies_pos" => 394,
+                "implies_neg1" => 214,
+                "implies_neg2" => 287,
+                "equiv_pos1" => 763,
+                "equiv_pos2" => 541,
+                "equiv_neg1" => 434,
+                "equiv_neg2" => 476,
+                "ite_pos1" => 804,
+                "ite_pos2" => 344,
+                "ite_neg1" => 566,
+                "ite_neg2" => 542,
+                "eq_reflexive" => 451,
+                "eq_transitive" => 780,
+                "eq_congruent" => 722,
+                "eq_congruent_pred" => 632,
+                "distinct_elim" => 812,
+                "la_rw_eq" => 1091,
+                "la_generic" => 87564,
+                "la_disequality" => 919,
+                "la_totality" => 0, //-1
+                "la_tautology" => 4291,
+                "forall_inst" => 7877,
+                "qnt_join" => 2347,
+                "qnt_rm_unused" => 3659,
+                "resolution" => 7491,
+                "th_resolution" => 2462,
+                "refl" => 1305,
+                "trans" => 575,
+                "cong" => 984,
+                "ho_cong" => 0, //-1
+                "and" => 493,
+                "tautology" => 0, //-1
+                "not_or" => 476,
+                "or" => 426,
+                "not_and" => 927,
+                "xor1" => 0, //-1
+                "xor2" => 0, //-1
+                "not_xor1" => 0, //-1
+                "not_xor2" => 0, //-1
+                "implies" => 788,
+                "not_implies1" => 402,
+                "not_implies2" => 484,
+                "equiv1" => 837,
+                "equiv2" => 812,
+                "not_equiv1" => 418,
+                "not_equiv2" => 451,
+                "ite1" => 509,
+                "ite2" => 493,
+                "not_ite1" => 722,
+                "not_ite2" => 476,
+                "ite_intro" => 3192,
+                "contraction" => 1731,
+                "connective_def" => 705,
+                "ite_simplify" => 1797,
+                "eq_simplify" => 845,
+                "and_simplify" => 1165,
+                "or_simplify" => 1133,
+                "not_simplify" => 787,
+                "implies_simplify" => 1231,
+                "equiv_simplify" => 1337,
+                "bool_simplify" => 1436,
+                "qnt_simplify" => 517,
+                "div_simplify" => 2117,
+                "prod_simplify" => 2527,
+                "unary_minus_simplify" => 0, //-1
+                "minus_simplify" => 1059,
+                "sum_simplify" => 2248,
+                "comp_simplify" => 1781,
+                "nary_elim" => 0, //-1
+                "ac_simp" => 9781,
+                "bfun_elim" => 8558,
+                "bind" => 5924,
+                "qnt_cnf" => 14244,
+                "subproof" => 262,
+                "let" => 4718,
+                "onepoint" => 7787,
+                "sko_ex" => 9321,
+                "sko_forall" => 12242,
+                "reordering" => 1452,
+                "symm" => 682,
+                "not_symm" => 0, //-1
+                "eq_symmetric" => 673,
+                "or_intro" => 508,
+                "bind_let" => 2324,
+                "la_mult_pos" => 1446,
+                "la_mult_neg" => 1447,
+                "hole" => 185, //Debug only
+                "trust" => 185, //Debug only
+                "strict_resolution" => 1276,
+
+                _ => 0,
+            }
+        }
+    }
+}
diff --git a/carcara/src/checker/rules/clausification.rs b/carcara/src/checker/rules/clausification.rs
index c21b0550..2dbed31b 100644
--- a/carcara/src/checker/rules/clausification.rs
+++ b/carcara/src/checker/rules/clausification.rs
@@ -1,10 +1,9 @@
 use super::{
-    assert_clause_len, assert_deep_eq_is_expected, assert_eq, assert_is_expected,
-    assert_num_premises, assert_operation_len, get_premise_term, CheckerError, EqualityError,
-    RuleArgs, RuleResult,
+    assert_clause_len, assert_eq, assert_is_expected, assert_num_premises, assert_operation_len,
+    assert_polyeq_expected, get_premise_term, CheckerError, EqualityError, RuleArgs, RuleResult,
 };
 use crate::ast::*;
-use ahash::AHashMap;
+use indexmap::IndexMap;
 
 pub fn distinct_elim(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult {
     assert_clause_len(conclusion, 1)?;
@@ -22,7 +21,7 @@ pub fn distinct_elim(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult
         }
         // If there are more than two boolean arguments to the distinct operator, the
         // second term must be `false`
-        args if *pool.sort(&args[0]) == Sort::Bool => {
+        args if pool.sort(&args[0]).as_sort().unwrap() == &Sort::Bool => {
             if second_term.is_bool_false() {
                 Ok(())
             } else {
@@ -207,7 +206,7 @@ pub fn nary_elim(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult {
     /// A function to expand terms that fall in the right or left associative cases. For example,
     /// the term `(=> p q r s)` will be expanded into the term `(=> p (=> q (=> r s)))`.
-    fn expand_assoc(pool: &mut TermPool, op: Operator, args: &[Rc<Term>], case: Case) -> Rc<Term> {
+    fn expand_assoc(
+        pool: &mut dyn TermPool,
+        op: Operator,
+        args: &[Rc<Term>],
+        case: Case,
+    ) -> Rc<Term> {
         let (head, tail) = match args {
             [] => unreachable!(),
             [t] => return t.clone(),
@@ -260,7 +264,7 @@ pub fn nary_elim(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult {
 
 /// The first simplification step for `bfun_elim`, that expands quantifiers over boolean variables.
 fn bfun_elim_first_step(
-    pool: &mut TermPool,
+    pool: &mut dyn TermPool,
     bindigns: &[SortedVar],
     term: &Rc<Term>,
     acc: &mut Vec<Rc<Term>>,
@@ -284,13 +288,15 @@ fn bfun_elim_first_step(
 /// The second simplification step for `bfun_elim`, that expands function applications over
 /// non-constant boolean arguments into `ite` terms.
 fn bfun_elim_second_step(
-    pool: &mut TermPool,
+    pool: &mut dyn TermPool,
     func: &Rc<Term>,
     args: &[Rc<Term>],
     processed: usize,
 ) -> Rc<Term> {
     for i in processed..args.len() {
-        if *pool.sort(&args[i]) == Sort::Bool && !args[i].is_bool_false() && !args[i].is_bool_true()
+        if pool.sort(&args[i]).as_sort().unwrap() == &Sort::Bool
+            && !args[i].is_bool_false()
+            && !args[i].is_bool_true()
         {
             let mut ite_args = Vec::with_capacity(3);
             ite_args.push(args[i].clone());
@@ -312,9 +318,9 @@ fn bfun_elim_second_step(
 
 /// Applies the simplification steps for the `bfun_elim` rule.
 fn apply_bfun_elim(
-    pool: &mut TermPool,
+    pool: &mut dyn TermPool,
     term: &Rc<Term>,
-    cache: &mut AHashMap<Rc<Term>, Rc<Term>>,
+    cache: &mut IndexMap<Rc<Term>, Rc<Term>>,
 ) -> Result<Rc<Term>, SubstitutionError> {
     if let Some(v) = cache.get(term) {
         return Ok(v.clone());
@@ -385,7 +391,7 @@ pub fn bfun_elim(
         conclusion,
         premises,
         pool,
-        deep_eq_time,
+        polyeq_time,
         ..
     }: RuleArgs,
 ) -> RuleResult {
@@ -394,8 +400,8 @@ pub fn bfun_elim(
     let psi = get_premise_term(&premises[0])?;
 
-    let expected = apply_bfun_elim(pool, psi, &mut AHashMap::new())?;
-    assert_deep_eq_is_expected(&conclusion[0], expected, deep_eq_time)
+    let expected = apply_bfun_elim(pool, psi, &mut IndexMap::new())?;
+    assert_polyeq_expected(&conclusion[0], expected, polyeq_time)
 }
 
 #[cfg(test)]
diff --git a/carcara/src/checker/rules/extras.rs b/carcara/src/checker/rules/extras.rs
index e936b1c3..7174f766 100644
--- a/carcara/src/checker/rules/extras.rs
+++ b/carcara/src/checker/rules/extras.rs
@@ -5,7 +5,7 @@ use super::{
     EqualityError, RuleArgs, RuleResult,
 };
 use crate::{ast::*, checker::rules::assert_operation_len};
-use ahash::AHashSet;
+use indexmap::IndexSet;
 
 pub fn reordering(RuleArgs { conclusion, premises, .. }: RuleArgs) -> RuleResult {
     assert_num_premises(premises, 1)?;
@@ -13,8 +13,8 @@ pub fn reordering(RuleArgs { conclusion, premises, .. }: RuleArgs) -> RuleResult
     let premise = premises[0].clause;
     assert_clause_len(conclusion, premise.len())?;
 
-    let premise_set: AHashSet<_> = premise.iter().collect();
-    let conclusion_set: AHashSet<_> = conclusion.iter().collect();
+    let premise_set: IndexSet<_> = premise.iter().collect();
+    let conclusion_set: IndexSet<_> = conclusion.iter().collect();
     if let Some(&t) = premise_set.difference(&conclusion_set).next() {
         Err(CheckerError::ContractionMissingTerm(t.clone()))
     } else if let Some(&t) = conclusion_set.difference(&premise_set).next() {
@@ -80,8 +80,8 @@ pub fn bind_let(
     let (left, right) = match_term_err!((= l r) = &conclusion[0])?;
 
-    let (l_bindings, left) = left.unwrap_let_err()?;
-    let (r_bindings, right) = right.unwrap_let_err()?;
+    let (l_bindings, left) = left.as_let_err()?;
+    let (r_bindings, right) = right.as_let_err()?;
 
     if l_bindings.len() != r_bindings.len() {
         return Err(EqualityError::ExpectedEqual(l_bindings.clone(), r_bindings.clone()).into());
     }
@@ -149,7 +149,7 @@ fn la_mult_generic(conclusion: &[Rc<Term>], is_pos: bool) -> RuleResult {
         CheckerError::ExpectedNumber(Rational::new(), zero.clone())
     );
 
-    let (op, args) = original.unwrap_op_err()?;
+    let (op, args) = original.as_op_err()?;
     assert_operation_len(op, args, 2)?;
 
     let (l, r) = (&args[0], &args[1]);
@@ -317,7 +317,7 @@ mod tests {
                 (step t1.t1 (cl (= x y)) :rule hole)
                 (step t1 (cl (= (let ((a 0)) x) (let ((b 0)) y))) :rule bind_let)": false,
             }
-            "Deep equality in variable values" {
+            "Polyequality in variable values" {
                 "(anchor :step t1 :args ((x Int) (y Int)))
                 (step t1.t1 (cl (= (= 0 1) (= 1 0))) :rule hole)
                 (step t1.t2 (cl (= x y)) :rule hole)
diff --git a/carcara/src/checker/rules/linear_arithmetic.rs b/carcara/src/checker/rules/linear_arithmetic.rs
index d81f2d06..4bd81066 100644
--- a/carcara/src/checker/rules/linear_arithmetic.rs
+++ b/carcara/src/checker/rules/linear_arithmetic.rs
@@ -3,7 +3,7 @@ use crate::{
     ast::*,
     checker::error::{CheckerError, LinearArithmeticError},
 };
-use ahash::AHashMap;
+use indexmap::{map::Entry, IndexMap};
 use rug::{ops::NegAssign, Integer, Rational};
 
 pub fn la_rw_eq(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult {
@@ -62,11 +62,11 @@ fn negate_disequality(term: &Rc<Term>) -> Result<(Operator, LinearComb, LinearCo
 /// plus a constant term. This is also used to represent a disequality, in which case the left side
 /// is the non-constant terms and their coefficients, and the right side is the constant term.
#[derive(Debug)] -pub struct LinearComb(pub(crate) AHashMap, Rational>, pub(crate) Rational); +pub struct LinearComb(pub(crate) IndexMap, Rational>, pub(crate) Rational); impl LinearComb { fn new() -> Self { - Self(AHashMap::new(), Rational::new()) + Self(IndexMap::new(), Rational::new()) } /// Flattens a term and adds it to the linear combination, multiplying by the coefficient @@ -125,8 +125,6 @@ impl LinearComb { } fn insert(&mut self, key: Rc, value: Rational) { - use std::collections::hash_map::Entry; - match self.0.entry(key) { Entry::Occupied(mut e) => { *e.get_mut() += value; @@ -185,7 +183,7 @@ impl LinearComb { } let mut result = self.1.numer().clone(); - for (_, coeff) in self.0.iter() { + for (_, coeff) in &self.0 { if result == 1 { return Integer::from(1); } diff --git a/carcara/src/checker/rules/mod.rs b/carcara/src/checker/rules/mod.rs index e244c6b9..6ba0962f 100644 --- a/carcara/src/checker/rules/mod.rs +++ b/carcara/src/checker/rules/mod.rs @@ -18,7 +18,7 @@ pub struct RuleArgs<'a> { pub(super) conclusion: &'a [Rc], pub(super) premises: &'a [Premise<'a>], pub(super) args: &'a [ProofArg], - pub(super) pool: &'a mut TermPool, + pub(super) pool: &'a mut dyn TermPool, pub(super) context: &'a mut ContextStack, // For rules that end a subproof, we need to pass the previous command in the subproof that it @@ -27,7 +27,7 @@ pub struct RuleArgs<'a> { pub(super) previous_command: Option>, pub(super) discharge: &'a [&'a ProofCommand], - pub(super) deep_eq_time: &'a mut Duration, + pub(super) polyeq_time: &'a mut Duration, } #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] @@ -128,19 +128,15 @@ where Ok(()) } -fn assert_deep_eq(a: &Rc, b: &Rc, time: &mut Duration) -> Result<(), CheckerError> { - if !deep_eq(a, b, time) { +fn assert_polyeq(a: &Rc, b: &Rc, time: &mut Duration) -> Result<(), CheckerError> { + if !polyeq(a, b, time) { return Err(EqualityError::ExpectedEqual(a.clone(), b.clone()).into()); } Ok(()) } -fn assert_deep_eq_is_expected( - got: &Rc, - expected: Rc, - time: &mut Duration, -) -> RuleResult { - if !deep_eq(got, &expected, time) { +fn assert_polyeq_expected(got: &Rc, expected: Rc, time: &mut Duration) -> RuleResult { + if !polyeq(got, &expected, time) { return Err(EqualityError::ExpectedToBe { expected, got: got.clone() }.into()); } Ok(()) @@ -155,34 +151,43 @@ fn assert_is_bool_constant(got: &Rc, expected: bool) -> RuleResult { #[cfg(test)] fn run_tests(test_name: &str, definitions: &str, cases: &[(&str, bool)]) { - use crate::{ - checker::{Config, ProofChecker}, - parser::parse_instance, - }; + use crate::{checker, parser}; use std::io::Cursor; for (i, (proof, expected)) in cases.iter().enumerate() { // This parses the definitions again for every case, which is not ideal - let (prelude, parsed, mut pool) = parse_instance( + let (prelude, mut proof, mut pool) = parser::parse_instance( Cursor::new(definitions), Cursor::new(proof), - true, - false, - false, + parser::Config::new(), ) .unwrap_or_else(|e| panic!("parser error during test \"{}\": {}", test_name, e)); - let mut checker = ProofChecker::new( - &mut pool, - Config { - strict: false, - skip_unknown_rules: false, - is_running_test: true, - statistics: None, - lia_via_cvc5: false, - }, - prelude, - ); - let got = checker.check(&parsed).is_ok(); + + // Since rule tests often use `assume` commands to introduce premises, we search the proof + // for all `assume`d terms and retroactively add them as the problem premises, to avoid + // checker errors on the `assume`s + proof.premises = proof + .commands + .iter() 
+ .filter_map(|c| match c { + ProofCommand::Assume { term, .. } => Some(term.clone()), + _ => None, + }) + .collect(); + + // All proofs must eventually reach the empty clause, so to avoid having to add a dummy + // `(step end (cl) :rule hole)` to every rule test, we add this dummy step here + proof.commands.push(ProofCommand::Step(ProofStep { + id: "end".into(), + clause: Vec::new(), + rule: "hole".into(), + premises: Vec::new(), + args: Vec::new(), + discharge: Vec::new(), + })); + + let mut checker = checker::ProofChecker::new(&mut pool, checker::Config::new(), &prelude); + let got = checker.check(&proof).is_ok(); assert_eq!( *expected, got, "test case \"{}\" index {} failed", diff --git a/carcara/src/checker/rules/quantifier.rs b/carcara/src/checker/rules/quantifier.rs index be63a46b..2cf2f780 100644 --- a/carcara/src/checker/rules/quantifier.rs +++ b/carcara/src/checker/rules/quantifier.rs @@ -1,17 +1,13 @@ use super::{ - assert_clause_len, assert_deep_eq_is_expected, assert_eq, assert_is_expected, assert_num_args, + assert_clause_len, assert_eq, assert_is_expected, assert_num_args, assert_polyeq_expected, CheckerError, RuleArgs, RuleResult, }; use crate::{ast::*, checker::error::QuantifierError, utils::DedupIterator}; -use ahash::{AHashMap, AHashSet}; +use indexmap::{IndexMap, IndexSet}; pub fn forall_inst( RuleArgs { - conclusion, - args, - pool, - deep_eq_time, - .. + conclusion, args, pool, polyeq_time, .. }: RuleArgs, ) -> RuleResult { assert_clause_len(conclusion, 1)?; @@ -23,12 +19,12 @@ pub fn forall_inst( // Since the bindings and arguments may not be in the same order, we collect the bindings into // a hash set, and remove each binding from it as we find the associated argument - let mut bindings: AHashSet<_> = bindings.iter().cloned().collect(); - let substitution: AHashMap<_, _> = args + let mut bindings: IndexSet<_> = bindings.iter().cloned().collect(); + let substitution: IndexMap<_, _> = args .iter() .map(|arg| { let (arg_name, arg_value) = arg.as_assign()?; - let arg_sort = pool.add(Term::Sort(pool.sort(arg_value).clone())); + let arg_sort = pool.sort(arg_value); rassert!( bindings.remove(&(arg_name.clone(), arg_sort.clone())), QuantifierError::NoBindingMatchesArg(arg_name.clone()) @@ -46,10 +42,9 @@ pub fn forall_inst( QuantifierError::NoArgGivenForBinding(bindings.iter().next().unwrap().0.clone()) ); - // Equalities may be reordered in the final term, so we need to use deep equality modulo - // reordering + // Equalities may be reordered in the final term, so we need to compare for polyequality here let expected = substitution.apply(pool, original); - assert_deep_eq_is_expected(substituted, expected, deep_eq_time) + assert_polyeq_expected(substituted, expected, polyeq_time) } pub fn qnt_join(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { @@ -57,9 +52,9 @@ pub fn qnt_join(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { let (left, right) = match_term_err!((= l r) = &conclusion[0])?; - let (q_1, bindings_1, left) = left.unwrap_quant_err()?; - let (q_2, bindings_2, left) = left.unwrap_quant_err()?; - let (q_3, bindings_3, right) = right.unwrap_quant_err()?; + let (q_1, bindings_1, left) = left.as_quant_err()?; + let (q_2, bindings_2, left) = left.as_quant_err()?; + let (q_3, bindings_3, right) = right.as_quant_err()?; assert_eq(&q_1, &q_2)?; assert_eq(&q_2, &q_3)?; @@ -81,9 +76,9 @@ pub fn qnt_rm_unused(RuleArgs { conclusion, pool, .. 
}: RuleArgs) -> RuleResult assert_clause_len(conclusion, 1)?; let (left, right) = match_term_err!((= l r) = &conclusion[0])?; - let (q_1, bindings_1, phi_1) = left.unwrap_quant_err()?; + let (q_1, bindings_1, phi_1) = left.as_quant_err()?; - let (bindings_2, phi_2) = match right.unwrap_quant() { + let (bindings_2, phi_2) = match right.as_quant() { Some((q_2, b, t)) => { assert_eq(&q_1, &q_2)?; (b, t) @@ -96,7 +91,7 @@ pub fn qnt_rm_unused(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult assert_eq(phi_1, phi_2)?; // Cloning here may be unnecessary - let free_vars = pool.free_vars(phi_1).clone(); + let free_vars = pool.free_vars(phi_1); let expected: Vec<_> = bindings_1 .iter() @@ -112,10 +107,10 @@ pub fn qnt_rm_unused(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult /// Converts a term into negation normal form, expanding all connectives. fn negation_normal_form( - pool: &mut TermPool, + pool: &mut dyn TermPool, term: &Rc, polarity: bool, - cache: &mut AHashMap<(Rc, bool), Rc>, + cache: &mut IndexMap<(Rc, bool), Rc>, ) -> Rc { if let Some(v) = cache.get(&(term.clone(), polarity)) { return v.clone(); @@ -153,13 +148,13 @@ fn negation_normal_form( true => build_term!(pool, (and (or {a} {b}) (or {c} {d}))), false => build_term!(pool, (or (and {a} {b}) (and {c} {d}))), } - } else if let Some((quant, bindings, inner)) = term.unwrap_quant() { + } else if let Some((quant, bindings, inner)) = term.as_quant() { let quant = if polarity { quant } else { !quant }; let inner = negation_normal_form(pool, inner, polarity, cache); pool.add(Term::Quant(quant, bindings.clone(), inner)) } else { match match_term!((= p q) = term) { - Some((left, right)) if *pool.sort(left) == Sort::Bool => { + Some((left, right)) if pool.sort(left).as_sort().unwrap() == &Sort::Bool => { let a = negation_normal_form(pool, left, !polarity, cache); let b = negation_normal_form(pool, right, polarity, cache); let c = negation_normal_form(pool, right, !polarity, cache); @@ -221,7 +216,7 @@ fn distribute(formulas: &[CnfFormula]) -> CnfFormula { /// Prenex all universal quantifiers in a term. This doesn't prenex existential quantifiers. This /// assumes the term is in negation normal form. -fn prenex_forall(pool: &mut TermPool, acc: &mut C, term: &Rc) -> Rc +fn prenex_forall(pool: &mut dyn TermPool, acc: &mut C, term: &Rc) -> Rc where C: Extend, { @@ -265,8 +260,8 @@ pub fn qnt_cnf(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult { let (l_bindings, phi, r_bindings, phi_prime) = { let (l, r) = match_term_err!((or (not l) r) = &conclusion[0])?; - let (l_q, l_b, phi) = l.unwrap_quant_err()?; - let (r_q, r_b, phi_prime) = r.unwrap_quant_err()?; + let (l_q, l_b, phi) = l.as_quant_err()?; + let (r_q, r_b, phi_prime) = r.as_quant_err()?; // We expect both quantifiers to be `forall` assert_is_expected(&l_q, Quantifier::Forall)?; @@ -275,10 +270,10 @@ pub fn qnt_cnf(RuleArgs { conclusion, pool, .. 
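`negation_normal_form` is a classic polarity-passing transformation, memoized on the `(term, polarity)` pair precisely because a shared subterm can be visited under both polarities. A self-contained toy version over a small boolean AST (the `Expr` type is invented for the sketch, and the special Bool-sorted equality case is omitted):

use indexmap::IndexMap;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum Expr {
    Var(String),
    Not(Rc<Expr>),
    And(Rc<Expr>, Rc<Expr>),
    Or(Rc<Expr>, Rc<Expr>),
}

fn nnf(
    term: &Rc<Expr>,
    polarity: bool,
    cache: &mut IndexMap<(Rc<Expr>, bool), Rc<Expr>>,
) -> Rc<Expr> {
    if let Some(v) = cache.get(&(term.clone(), polarity)) {
        return v.clone();
    }
    let result = match term.as_ref() {
        // A negation just flips the polarity of its subterm
        Expr::Not(inner) => nnf(inner, !polarity, cache),
        // Under negative polarity, De Morgan swaps the connective
        Expr::And(a, b) => {
            let (a, b) = (nnf(a, polarity, cache), nnf(b, polarity, cache));
            if polarity { Rc::new(Expr::And(a, b)) } else { Rc::new(Expr::Or(a, b)) }
        }
        Expr::Or(a, b) => {
            let (a, b) = (nnf(a, polarity, cache), nnf(b, polarity, cache));
            if polarity { Rc::new(Expr::Or(a, b)) } else { Rc::new(Expr::And(a, b)) }
        }
        // A variable under negative polarity becomes a literal negation
        Expr::Var(_) if polarity => term.clone(),
        Expr::Var(_) => Rc::new(Expr::Not(term.clone())),
    };
    cache.insert((term.clone(), polarity), result.clone());
    result
}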
}: RuleArgs) -> RuleResult { (l_b, phi, r_b, phi_prime) }; - let r_bindings = r_bindings.iter().cloned().collect::>(); - let mut new_bindings = l_bindings.iter().cloned().collect::>(); + let r_bindings = r_bindings.iter().cloned().collect::>(); + let mut new_bindings = l_bindings.iter().cloned().collect::>(); let clauses: Vec<_> = { - let nnf = negation_normal_form(pool, phi, true, &mut AHashMap::new()); + let nnf = negation_normal_form(pool, phi, true, &mut IndexMap::new()); let prenexed = prenex_forall(pool, &mut new_bindings, &nnf); let cnf = conjunctive_normal_form(&prenexed); cnf.into_iter() @@ -292,7 +287,7 @@ pub fn qnt_cnf(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult { // `new_bindings` contains all bindings that existed in the original term, plus all bindings // added by the prenexing step. All bindings in the right side must be in this set - if let Some((var, _)) = r_bindings.iter().find(|b| !new_bindings.contains(b)) { + if let Some((var, _)) = r_bindings.iter().find(|&b| !new_bindings.contains(b)) { return Err(CheckerError::Quant( QuantifierError::CnfNewBindingIntroduced(var.clone()), )); @@ -304,7 +299,7 @@ pub fn qnt_cnf(RuleArgs { conclusion, pool, .. }: RuleArgs) -> RuleResult { .ok_or_else(|| QuantifierError::ClauseDoesntAppearInCnf(phi_prime.clone()))?; // Cloning here may be unnecessary - let free_vars = pool.free_vars(selected_clause).clone(); + let free_vars = pool.free_vars(selected_clause); // While all bindings in `r_bindings` must also be in `new_bindings`, the same is not true in // the opposite direction. That is because some variables from the set may be omitted in the @@ -469,8 +464,8 @@ mod tests { use super::*; use crate::parser::tests::*; - fn to_cnf_term(pool: &mut TermPool, term: &Rc) -> Rc { - let nnf = negation_normal_form(pool, term, true, &mut AHashMap::new()); + fn to_cnf_term(pool: &mut dyn TermPool, term: &Rc) -> Rc { + let nnf = negation_normal_form(pool, term, true, &mut IndexMap::new()); let mut bindings = Vec::new(); let prenexed = prenex_forall(pool, &mut bindings, &nnf); let cnf = conjunctive_normal_form(&prenexed); @@ -502,7 +497,7 @@ mod tests { fn run_tests(definitions: &str, cases: &[(&str, &str)]) { for &(term, expected) in cases { - let mut pool = TermPool::new(); + let mut pool = crate::ast::pool::PrimitivePool::new(); let [term, expected] = parse_terms(&mut pool, definitions, [term, expected]); let got = to_cnf_term(&mut pool, &term); assert_eq!(expected, got); diff --git a/carcara/src/checker/rules/reflexivity.rs b/carcara/src/checker/rules/reflexivity.rs index a826f302..531b1587 100644 --- a/carcara/src/checker/rules/reflexivity.rs +++ b/carcara/src/checker/rules/reflexivity.rs @@ -12,7 +12,7 @@ pub fn refl( conclusion, pool, context, - deep_eq_time, + polyeq_time, .. }: RuleArgs, ) -> RuleResult { @@ -23,7 +23,7 @@ pub fn refl( // If the two terms are directly identical, we don't need to do any more work. 
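For reference, `conjunctive_normal_form` bottoms out in `distribute`, which is a cartesian product over the clause lists of the disjuncts. A stand-alone sketch assuming the same list-of-clauses representation, with `&str` literals (a clause is a `Vec<&str>`, a CNF formula a vector of clauses):

fn distribute<'a>(formulas: &[Vec<Vec<&'a str>>]) -> Vec<Vec<&'a str>> {
    match formulas {
        // Disjunction of zero formulas: a single empty clause
        [] => vec![Vec::new()],
        [head, tail @ ..] => {
            let rest = distribute(tail);
            head.iter()
                .flat_map(|clause| {
                    // Join each clause of the head with each clause of the rest
                    rest.iter().map(move |r| {
                        let mut joined = clause.clone();
                        joined.extend(r.iter().copied());
                        joined
                    })
                })
                .collect()
        }
    }
}

fn main() {
    // (a and b) or (c and d)  =>  (a or c) and (a or d) and (b or c) and (b or d)
    let cnf = distribute(&[vec![vec!["a"], vec!["b"]], vec![vec!["c"], vec!["d"]]]);
    assert_eq!(cnf, vec![vec!["a", "c"], vec!["a", "d"], vec!["b", "c"], vec!["b", "d"]]);
}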
We make sure to // do this check before we try to get the context substitution, because `refl` can be used // outside of any subproof - if are_alpha_equivalent(left, right, deep_eq_time) { + if alpha_equiv(left, right, polyeq_time) { return Ok(()); } @@ -36,10 +36,10 @@ pub fn refl( // don't compute the new left and right terms until they are needed, to avoid doing unnecessary // work let new_left = context.apply(pool, left); - let result = are_alpha_equivalent(&new_left, right, deep_eq_time) || { + let result = alpha_equiv(&new_left, right, polyeq_time) || { let new_right = context.apply(pool, right); - are_alpha_equivalent(left, &new_right, deep_eq_time) - || are_alpha_equivalent(&new_left, &new_right, deep_eq_time) + alpha_equiv(left, &new_right, polyeq_time) + || alpha_equiv(&new_left, &new_right, polyeq_time) }; rassert!( result, @@ -75,14 +75,14 @@ pub fn strict_refl(RuleArgs { conclusion, pool, context, .. }: RuleArgs) -> Rule fn elaborate_equality( elaborator: &mut Elaborator, - pool: &mut TermPool, + pool: &mut dyn TermPool, left: &Rc, right: &Rc, id: &str, - deep_eq_time: &mut std::time::Duration, + polyeq_time: &mut std::time::Duration, ) -> (usize, usize) { - let is_alpha_equivalence = !deep_eq(left, right, deep_eq_time); - elaborator.elaborate_deep_eq(pool, id, left.clone(), right.clone(), is_alpha_equivalence) + let is_alpha_equivalence = !polyeq(left, right, polyeq_time); + elaborator.elaborate_polyeq(pool, id, left.clone(), right.clone(), is_alpha_equivalence) } pub fn elaborate_refl( @@ -90,7 +90,7 @@ pub fn elaborate_refl( conclusion, pool, context, - deep_eq_time, + polyeq_time, .. }: RuleArgs, command_id: String, @@ -122,12 +122,12 @@ pub fn elaborate_refl( // directly. In the second case, we need to first apply the context to the left term, using a // `refl` step, and then prove the equivalence of the new left term with the right term. In the // third case, we also need to apply the context to the right term, using another `refl` step. - if are_alpha_equivalent(left, right, deep_eq_time) { + if alpha_equiv(left, right, polyeq_time) { let equality_step = - elaborate_equality(elaborator, pool, left, right, &command_id, deep_eq_time); + elaborate_equality(elaborator, pool, left, right, &command_id, polyeq_time); let id = elaborator.get_new_id(&command_id); - // TODO: Elaborating the deep equality will add new commands to the accumulator, but + // TODO: Elaborating the polyequality will add new commands to the accumulator, but // currently we can't push them as the elaborated step directly, so we need to add this // dummy `reordering` step. 
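The shape of the `refl` check is worth noting: `alpha_equiv` on the raw terms is tried first, and the context is applied only when that fails, with `||` short-circuiting past the more expensive combinations. Schematically, with `apply` and `equiv` as stand-ins for `ContextStack::apply` and `alpha_equiv`:

fn refl_holds(
    left: &str,
    right: &str,
    apply: impl Fn(&str) -> String,
    equiv: impl Fn(&str, &str) -> bool,
) -> bool {
    if equiv(left, right) {
        return true; // directly equivalent, no substitution needed
    }
    let new_left = apply(left);
    equiv(&new_left, right) || {
        // Only computed if the previous comparisons failed
        let new_right = apply(right);
        equiv(left, &new_right) || equiv(&new_left, &new_right)
    }
}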
elaborator.push_elaborated_step(ProofStep { @@ -142,15 +142,9 @@ pub fn elaborate_refl( let id = elaborator.get_new_id(&command_id); let first_step = elaborator.add_refl_step(pool, left.clone(), new_left.clone(), id); - if are_alpha_equivalent(&new_left, right, deep_eq_time) { - let second_step = elaborate_equality( - elaborator, - pool, - &new_left, - right, - &command_id, - deep_eq_time, - ); + if alpha_equiv(&new_left, right, polyeq_time) { + let second_step = + elaborate_equality(elaborator, pool, &new_left, right, &command_id, polyeq_time); let id = elaborator.get_new_id(&command_id); elaborator.push_elaborated_step(ProofStep { id, @@ -160,15 +154,9 @@ pub fn elaborate_refl( args: Vec::new(), discharge: Vec::new(), }); - } else if are_alpha_equivalent(&new_left, &new_right, deep_eq_time) { - let second_step = elaborate_equality( - elaborator, - pool, - &new_left, - right, - &command_id, - deep_eq_time, - ); + } else if alpha_equiv(&new_left, &new_right, polyeq_time) { + let second_step = + elaborate_equality(elaborator, pool, &new_left, right, &command_id, polyeq_time); let id = elaborator.get_new_id(&command_id); let third_step = elaborator.add_refl_step(pool, new_right.clone(), right.clone(), id); diff --git a/carcara/src/checker/rules/resolution.rs b/carcara/src/checker/rules/resolution.rs index 1de940e1..08bc8301 100644 --- a/carcara/src/checker/rules/resolution.rs +++ b/carcara/src/checker/rules/resolution.rs @@ -7,8 +7,8 @@ use crate::{ checker::{error::ResolutionError, Elaborator}, utils::DedupIterator, }; -use ahash::{AHashMap, AHashSet}; -use std::{collections::hash_map::Entry, iter::FromIterator}; +use indexmap::{map::Entry, IndexMap, IndexSet}; +use std::iter::FromIterator; type ResolutionTerm<'a> = (u32, &'a Rc); @@ -34,7 +34,7 @@ impl<'a> ClauseCollection<'a> for Vec> { } } -impl<'a> ClauseCollection<'a> for AHashSet> { +impl<'a> ClauseCollection<'a> for IndexSet> { fn insert_term(&mut self, item: ResolutionTerm<'a>) { self.insert(item); } @@ -45,7 +45,7 @@ impl<'a> ClauseCollection<'a> for AHashSet> { } /// Undoes the transformation done by `Rc::remove_all_negations`. -fn unremove_all_negations(pool: &mut TermPool, (n, term): ResolutionTerm) -> Rc { +fn unremove_all_negations(pool: &mut dyn TermPool, (n, term): ResolutionTerm) -> Rc { let mut term = term.clone(); for _ in 0..n { term = build_term!(pool, (not { term })); @@ -94,7 +94,7 @@ struct ResolutionTrace { fn greedy_resolution( conclusion: &[Rc], premises: &[Premise], - pool: &mut TermPool, + pool: &mut dyn TermPool, tracing: bool, ) -> Result { // If we are elaborating, we record which pivot was found for each binary resolution step, so we @@ -113,21 +113,21 @@ fn greedy_resolution( // Without looking at the conclusion, it is unclear if the (not p) term should be removed by the // p term, or if the (not (not p)) should be removed by the (not (not (not p))). We can only // determine this by looking at the conclusion and using it to derive the pivots. - let conclusion: AHashSet<_> = conclusion + let conclusion: IndexSet<_> = conclusion .iter() .map(Rc::remove_all_negations) .map(|(n, t)| (n as i32, t)) .collect(); // The working clause contains the terms from the conclusion clause that we already encountered - let mut working_clause = AHashSet::new(); + let mut working_clause = IndexSet::new(); // The pivots are the encountered terms that are not present in the conclusion clause, and so // should be removed. After being used to eliminate a term, a pivot can still be used to // eliminate other terms. 
Because of that, we represent the pivots as a hash map to a boolean, // which represents if the pivot was already eliminated or not. At the end, this boolean should // be true for all pivots - let mut pivots = AHashMap::new(); + let mut pivots = IndexMap::new(); for premise in premises { // Only one pivot may be eliminated per clause. This restriction is required so logically @@ -249,7 +249,7 @@ fn greedy_resolution( } fn rup_resolution(conclusion: &[Rc], premises: &[Premise]) -> bool { - let mut clauses: Vec)>> = premises + let mut clauses: Vec)>> = premises .iter() .map(|p| { p.clause @@ -260,7 +260,7 @@ fn rup_resolution(conclusion: &[Rc], premises: &[Premise]) -> bool { .collect(); clauses.extend(conclusion.iter().map(|t| { let (p, t) = t.remove_all_negations_with_polarity(); - let mut clause = AHashSet::new(); + let mut clause = IndexSet::new(); clause.insert((!p, t)); clause })); @@ -294,9 +294,9 @@ pub fn resolution_with_args( conclusion, premises, args, pool, .. }: RuleArgs, ) -> RuleResult { - let resolution_result = apply_generic_resolution::>(premises, args, pool)?; + let resolution_result = apply_generic_resolution::>(premises, args, pool)?; - let conclusion: AHashSet<_> = conclusion.iter().map(Rc::remove_all_negations).collect(); + let conclusion: IndexSet<_> = conclusion.iter().map(Rc::remove_all_negations).collect(); if let Some(extra) = conclusion.difference(&resolution_result).next() { let extra = unremove_all_negations(pool, *extra); @@ -341,7 +341,7 @@ pub fn strict_resolution( fn apply_generic_resolution<'a, C: ClauseCollection<'a>>( premises: &'a [Premise], args: &'a [ProofArg], - pool: &mut TermPool, + pool: &mut dyn TermPool, ) -> Result { assert_num_premises(premises, 2..)?; let num_steps = premises.len() - 1; @@ -377,7 +377,7 @@ fn apply_generic_resolution<'a, C: ClauseCollection<'a>>( } fn binary_resolution<'a, C: ClauseCollection<'a>>( - pool: &mut TermPool, + pool: &mut dyn TermPool, current: &mut C, next: &'a [Rc], pivot: ResolutionTerm<'a>, @@ -551,7 +551,7 @@ pub fn tautology(RuleArgs { conclusion, premises, .. }: RuleArgs) -> RuleResult assert_is_bool_constant(&conclusion[0], true)?; let premise = premises[0].clause; - let mut seen = AHashSet::with_capacity(premise.len()); + let mut seen = IndexSet::with_capacity(premise.len()); let with_negations_removed = premise.iter().map(Rc::remove_all_negations_with_polarity); for (polarity, term) in with_negations_removed { if seen.contains(&(!polarity, term)) { @@ -565,8 +565,8 @@ pub fn tautology(RuleArgs { conclusion, premises, .. }: RuleArgs) -> RuleResult pub fn contraction(RuleArgs { conclusion, premises, .. 
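The `(negation count, atom)` encoding produced by `remove_all_negations` is what makes this pivot search tractable: two literals resolve against each other exactly when they share an atom and their negation counts differ by one. A toy string-based version of the encoding and the resolvability test:

fn remove_all_negations(mut lit: &str) -> (u32, &str) {
    // Peel `(not ...)` wrappers, counting how many were removed
    let mut n = 0;
    while let Some(inner) = lit.strip_prefix("(not ").and_then(|s| s.strip_suffix(')')) {
        lit = inner;
        n += 1;
    }
    (n, lit)
}

fn can_resolve(a: &str, b: &str) -> bool {
    let ((n1, t1), (n2, t2)) = (remove_all_negations(a), remove_all_negations(b));
    t1 == t2 && (i64::from(n1) - i64::from(n2)).abs() == 1
}

fn main() {
    assert!(can_resolve("(not p)", "p"));
    assert!(can_resolve("(not (not p))", "(not (not (not p)))"));
    assert!(!can_resolve("(not p)", "q"));
}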
}: RuleArgs) -> RuleResult { assert_num_premises(premises, 1)?; - let premise_set: AHashSet<_> = premises[0].clause.iter().collect(); - let conclusion_set: AHashSet<_> = conclusion.iter().collect(); + let premise_set: IndexSet<_> = premises[0].clause.iter().collect(); + let conclusion_set: IndexSet<_> = conclusion.iter().collect(); if let Some(&t) = premise_set.difference(&conclusion_set).next() { Err(CheckerError::ContractionMissingTerm(t.clone())) } else if let Some(&t) = conclusion_set.difference(&premise_set).next() { diff --git a/carcara/src/checker/rules/simplification.rs b/carcara/src/checker/rules/simplification.rs index 9e33e096..e06f2567 100644 --- a/carcara/src/checker/rules/simplification.rs +++ b/carcara/src/checker/rules/simplification.rs @@ -3,7 +3,7 @@ use super::{ RuleResult, }; use crate::{ast::*, utils::DedupIterator}; -use ahash::{AHashMap, AHashSet}; +use indexmap::{IndexMap, IndexSet}; use rug::Rational; /// A macro to define the possible transformations for a "simplify" rule. @@ -40,15 +40,15 @@ macro_rules! simplify { fn generic_simplify_rule( conclusion: &[Rc], - pool: &mut TermPool, - simplify_function: fn(&Term, &mut TermPool) -> Option>, + pool: &mut dyn TermPool, + simplify_function: fn(&Term, &mut dyn TermPool) -> Option>, ) -> RuleResult { assert_clause_len(conclusion, 1)?; let mut simplify_until_fixed_point = |term: &Rc, goal: &Rc| -> Result, CheckerError> { let mut current = term.clone(); - let mut seen = AHashSet::new(); + let mut seen = IndexSet::new(); loop { if !seen.insert(current.clone()) { return Err(CheckerError::CycleInSimplification(current)); @@ -160,7 +160,7 @@ pub fn eq_simplify(args: RuleArgs) -> RuleResult { /// Used for both the `and_simplify` and `or_simplify` rules, depending on `rule_kind`. `rule_kind` /// has to be either `Operator::And` or `Operator::Or`. fn generic_and_or_simplify( - pool: &mut TermPool, + pool: &mut dyn TermPool, conclusion: &[Rc], rule_kind: Operator, ) -> RuleResult { @@ -213,14 +213,14 @@ fn generic_and_or_simplify( // Then, we remove all duplicate terms. We do this in place to avoid another allocation. // Similarly to the step that removes the "skip term", we check if we already found the result // after this step. This is also necessary in some examples - let mut seen = AHashSet::with_capacity(phis.len()); + let mut seen = IndexSet::with_capacity(phis.len()); phis.retain(|t| seen.insert(t.clone())); if result_args.iter().eq(&phis) { return Ok(()); } // Finally, we check to see if the result was short-circuited - let seen: AHashSet<(bool, &Rc)> = phis + let seen: IndexSet<(bool, &Rc)> = phis .iter() .map(Rc::remove_all_negations_with_polarity) .collect(); @@ -349,10 +349,10 @@ pub fn equiv_simplify(args: RuleArgs) -> RuleResult { (= phi_1 false): (phi_1, _) => build_term!(pool, (not {phi_1.clone()})), // This is a special case for the `equiv_simplify` rule that was added to make - // elaboration of deep equalities less verbose. This transformation can very easily lead - // to cycles, so it must always be the last transformation rule. Unfortunately, this - // means that failed simplifications in the `equiv_simplify` rule will frequently reach - // this transformation and reach a cycle, in which case the error message may be a bit + // elaboration of polyequality less verbose. This transformation can very easily lead to + // cycles, so it must always be the last transformation rule. 
Unfortunately, this means + // that failed simplifications in the `equiv_simplify` rule will frequently reach this + // transformation and reach a cycle, in which case the error message may be a bit // confusing. // // phi_1 = phi_2 => phi_2 = phi_1 @@ -407,7 +407,7 @@ pub fn bool_simplify(args: RuleArgs) -> RuleResult { pub fn qnt_simplify(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { assert_clause_len(conclusion, 1)?; let (left, right) = match_term_err!((= l r) = &conclusion[0])?; - let (_, _, inner) = left.unwrap_quant_err()?; + let (_, _, inner) = left.as_quant_err()?; rassert!( inner.is_bool_false() || inner.is_bool_true(), CheckerError::ExpectedAnyBoolConstant(inner.clone()) @@ -430,7 +430,7 @@ pub fn div_simplify(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { CheckerError::ExpectedNumber(Rational::new(), right.clone()) ); Ok(()) - } else if t_2.as_number().map_or(false, |n| n == 1) { + } else if t_2.as_number().is_some_and(|n| n == 1) { assert_eq(right, t_1) } else { let expected = t_1.as_signed_number_err()? / t_2.as_signed_number_err()?; @@ -445,7 +445,7 @@ pub fn div_simplify(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { /// Used for both the `sum_simplify` and `prod_simplify` rules, depending on `rule_kind`. /// `rule_kind` has to be either `Operator::Add` or `Operator::Mult`. fn generic_sum_prod_simplify_rule( - pool: &mut TermPool, + pool: &mut dyn TermPool, ts: &Rc, u: &Rc, rule_kind: Operator, @@ -531,7 +531,7 @@ fn generic_sum_prod_simplify_rule( // Finally, we verify that the constant and the remaining arguments are what we expect rassert!(u_constant == constant_total && u_args.iter().eq(result), { let expected = { - let mut expected_args = vec![pool.add(Term::Terminal(Terminal::Real(constant_total)))]; + let mut expected_args = vec![pool.add(Term::new_real(constant_total))]; expected_args.extend(u_args.iter().cloned()); pool.add(Term::Op(rule_kind, expected_args)) }; @@ -557,7 +557,7 @@ pub fn minus_simplify(RuleArgs { conclusion, .. }: RuleArgs) -> RuleResult { // the `minus_simplify` and the `unary_minus_simplify` rules fn try_unary_minus_simplify(t: &Rc, u: &Rc) -> bool { // First case of `unary_minus_simplify` - if match_term!((-(-t)) = t).map_or(false, |t| t == u) { + if match_term!((-(-t)) = t) == Some(u) { return true; } @@ -667,8 +667,8 @@ pub fn comp_simplify(args: RuleArgs) -> RuleResult { } fn apply_ac_simp( - pool: &mut TermPool, - cache: &mut AHashMap, Rc>, + pool: &mut dyn TermPool, + cache: &mut IndexMap, Rc>, term: &Rc, ) -> Rc { if let Some(t) = cache.get(term) { @@ -723,7 +723,7 @@ pub fn ac_simp(RuleArgs { conclusion, pool, .. 
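The cycle hazard described in that comment is exactly why `generic_simplify_rule` pairs its fixed-point loop with a seen-set. A reduced sketch of that loop, with a placeholder one-step simplifier over strings:

use indexmap::IndexSet;

fn simplify_until_fixed_point(
    start: String,
    step: impl Fn(&str) -> Option<String>,
) -> Result<String, String> {
    let mut current = start;
    let mut seen = IndexSet::new();
    loop {
        // Revisiting a term means the rewrite system is looping
        if !seen.insert(current.clone()) {
            return Err(format!("cycle in simplification at '{current}'"));
        }
        match step(&current) {
            Some(next) => current = next,
            None => return Ok(current), // no rule applies: fixed point reached
        }
    }
}

fn main() {
    // A silly simplifier that strips one leading "--" pair at a time: --x => x
    let step = |s: &str| s.strip_prefix("--").map(str::to_owned);
    assert_eq!(simplify_until_fixed_point("----x".to_owned(), step).unwrap(), "x");
}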
}: RuleArgs) -> RuleResult { let (original, flattened) = match_term_err!((= psi phis) = &conclusion[0])?; assert_eq( flattened, - &apply_ac_simp(pool, &mut AHashMap::new(), original), + &apply_ac_simp(pool, &mut IndexMap::new(), original), ) } diff --git a/carcara/src/checker/rules/subproof.rs b/carcara/src/checker/rules/subproof.rs index 060b083b..7a347d3a 100644 --- a/carcara/src/checker/rules/subproof.rs +++ b/carcara/src/checker/rules/subproof.rs @@ -3,7 +3,7 @@ use super::{ CheckerError, EqualityError, RuleArgs, RuleResult, }; use crate::{ast::*, checker::error::SubproofError}; -use ahash::{AHashMap, AHashSet}; +use indexmap::{IndexMap, IndexSet}; pub fn subproof( RuleArgs { @@ -64,14 +64,14 @@ pub fn bind( // While the documentation indicates this rule is only called with `forall` quantifiers, in // some of the tests examples it is also called with the `exists` quantifier - let (l_quant, l_bindings, left) = left.unwrap_quant_err()?; - let (r_quant, r_bindings, right) = right.unwrap_quant_err()?; + let (l_quant, l_bindings, left) = left.as_quant_err()?; + let (r_quant, r_bindings, right) = right.as_quant_err()?; assert_eq(&l_quant, &r_quant)?; let [l_bindings, r_bindings] = [l_bindings, r_bindings].map(|b| { b.iter() .map(|var| pool.add(var.clone().into())) - .collect::>() + .collect::>() }); // The terms in the quantifiers must be phi and phi' @@ -91,9 +91,10 @@ pub fn bind( // Since we are closing a subproof, we only care about the substitutions that were introduced // in it let context = context.last().unwrap(); + let context = context.as_ref().unwrap(); // The quantifier binders must be the xs and ys of the context substitution - let (xs, ys): (AHashSet<_>, AHashSet<_>) = context + let (xs, ys): (IndexSet<_>, IndexSet<_>) = context .mappings .iter() // We skip terms which are not simply variables @@ -142,8 +143,15 @@ pub fn r#let( // Since we are closing a subproof, we only care about the substitutions that were introduced // in it - let substitution: AHashMap, Rc> = - context.last().unwrap().mappings.iter().cloned().collect(); + let substitution: IndexMap, Rc> = context + .last() + .unwrap() + .as_ref() + .unwrap() + .mappings + .iter() + .cloned() + .collect(); let (let_term, u_prime) = match_term_err!((= l u) = &conclusion[0])?; let Term::Let(let_bindings, u) = let_term.as_ref() else { @@ -166,7 +174,7 @@ pub fn r#let( let mut pairs: Vec<_> = let_bindings .iter() .map(|(x, t)| { - let sort = pool.add(Term::Sort(pool.sort(t).clone())); + let sort = pool.sort(t); let x_term = pool.add((x.clone(), sort).into()); let s = substitution .get(&x_term) @@ -191,15 +199,15 @@ pub fn r#let( Ok(()) } -fn extract_points(quant: Quantifier, term: &Rc) -> AHashSet<(Rc, Rc)> { - fn find_points(acc: &mut AHashSet<(Rc, Rc)>, polarity: bool, term: &Rc) { +fn extract_points(quant: Quantifier, term: &Rc) -> IndexSet<(Rc, Rc)> { + fn find_points(acc: &mut IndexSet<(Rc, Rc)>, polarity: bool, term: &Rc) { // This does not make use of a cache, so there may be performance issues // TODO: Measure the performance of this function, and see if a cache is needed if let Some(inner) = term.remove_negation() { return find_points(acc, !polarity, inner); } - if let Some((_, _, inner)) = term.unwrap_quant() { + if let Some((_, _, inner)) = term.as_quant() { return find_points(acc, polarity, inner); } match polarity { @@ -225,7 +233,7 @@ fn extract_points(quant: Quantifier, term: &Rc) -> AHashSet<(Rc, Rc< } } - let mut result = AHashSet::new(); + let mut result = IndexSet::new(); find_points(&mut result, quant == 
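`apply_ac_simp` above is a top-down memoized rewrite: flatten the arguments first, then splice any nested application of the same operator into the parent. A toy version for a single `and` operator, keeping the same `IndexMap` cache shape:

use indexmap::IndexMap;
use std::rc::Rc;

#[derive(Debug, PartialEq, Eq, Hash, Clone)]
enum Expr {
    Var(String),
    And(Vec<Rc<Expr>>),
}

fn flatten(term: &Rc<Expr>, cache: &mut IndexMap<Rc<Expr>, Rc<Expr>>) -> Rc<Expr> {
    if let Some(t) = cache.get(term) {
        return t.clone(); // already rewrote this (possibly shared) subterm
    }
    let result = match term.as_ref() {
        Expr::Var(_) => term.clone(),
        Expr::And(args) => {
            let mut new_args = Vec::new();
            for arg in args {
                let flat = flatten(arg, cache);
                match flat.as_ref() {
                    // Splice a nested `and` directly into the parent
                    Expr::And(inner) => new_args.extend(inner.iter().cloned()),
                    _ => new_args.push(Rc::clone(&flat)),
                }
            }
            Rc::new(Expr::And(new_args))
        }
    };
    cache.insert(term.clone(), result.clone());
    result
}

fn main() {
    let (a, b) = (Rc::new(Expr::Var("a".into())), Rc::new(Expr::Var("b".into())));
    let nested = Rc::new(Expr::And(vec![a.clone(), Rc::new(Expr::And(vec![b.clone()]))]));
    let flat = flatten(&nested, &mut IndexMap::new());
    assert_eq!(flat, Rc::new(Expr::And(vec![a, b])));
}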
Quantifier::Exists, term); result } @@ -244,8 +252,8 @@ pub fn onepoint( assert_clause_len(conclusion, 1)?; let (left, right) = match_term_err!((= l r) = &conclusion[0])?; - let (quant, l_bindings, left) = left.unwrap_quant_err()?; - let (r_bindings, right) = match right.unwrap_quant() { + let (quant, l_bindings, left) = left.as_quant_err()?; + let (r_bindings, right) = match right.as_quant() { Some((q, b, t)) => { assert_eq(&q, &quant)?; (b, t) @@ -265,27 +273,32 @@ pub fn onepoint( } ); - let last_context = context.last_mut().unwrap(); - if let Some((var, _)) = r_bindings - .iter() - .find(|b| !last_context.bindings.contains(b)) - { + let last_context = context.last().unwrap(); + if let Some((var, _)) = { + let last_context = last_context.as_ref().unwrap(); + r_bindings + .iter() + .find(|&b| !last_context.bindings.contains(b)) + } { return Err(SubproofError::BindingIsNotInContext(var.clone()).into()); } - let l_bindings_set: AHashSet<_> = l_bindings + let l_bindings_set: IndexSet<_> = l_bindings .iter() .map(|var| pool.add(var.clone().into())) .collect(); - let r_bindings_set: AHashSet<_> = r_bindings + let r_bindings_set: IndexSet<_> = r_bindings .iter() .map(|var| pool.add(var.clone().into())) .collect(); - let substitution_vars: AHashSet<_> = last_context + let substitution_vars: IndexSet<_> = last_context + .as_ref() + .unwrap() .mappings .iter() .map(|(k, _)| k.clone()) .collect(); + drop(last_context); let points = extract_points(quant, left); @@ -293,13 +306,14 @@ pub fn onepoint( // substitution to the points in order to replace these variables by their value. We also // create a duplicate of every point in the reverse order, since the order of equalities may be // flipped - let points: AHashSet<_> = points + let points: IndexSet<_> = points .into_iter() .flat_map(|(x, t)| [(x.clone(), t.clone()), (t, x)]) .map(|(x, t)| (x, context.apply(pool, &t))) .collect(); - let last_context = context.last_mut().unwrap(); + let last_context = context.last().unwrap(); + let last_context = last_context.as_ref().unwrap(); // For each substitution (:= x t) in the context, the equality (= x t) must appear in phi if let Some((k, v)) = last_context .mappings @@ -330,7 +344,7 @@ fn generic_skolemization_rule( pool, context, previous_command, - deep_eq_time, + polyeq_time, .. 
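Note the `flat_map` in `onepoint` above: it closes the point set under flipping, so a substitution `(:= x t)` will later match both `(= x t)` and `(= t x)`. Reduced to plain pairs:

fn close_under_flip(points: Vec<(String, String)>) -> Vec<(String, String)> {
    points
        .into_iter()
        .flat_map(|(x, t)| [(x.clone(), t.clone()), (t, x)])
        .collect()
}

fn main() {
    let flipped = close_under_flip(vec![("x".to_owned(), "t".to_owned())]);
    assert_eq!(flipped.len(), 2); // contains both (x, t) and (t, x)
}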
}: RuleArgs, ) -> RuleResult { @@ -340,7 +354,7 @@ fn generic_skolemization_rule( let (left, psi) = match_term_err!((= l r) = &conclusion[0])?; - let (quant, bindings, phi) = left.unwrap_quant_err()?; + let (quant, bindings, phi) = left.as_quant_err()?; assert_is_expected(&quant, rule_type)?; let previous_term = get_premise_term(&previous_command)?; @@ -353,8 +367,15 @@ fn generic_skolemization_rule( current_phi = context.apply_previous(pool, ¤t_phi); } - let substitution: AHashMap, Rc> = - context.last().unwrap().mappings.iter().cloned().collect(); + let substitution: IndexMap, Rc> = context + .last() + .unwrap() + .as_ref() + .unwrap() + .mappings + .iter() + .cloned() + .collect(); for (i, x) in bindings.iter().enumerate() { let x_term = pool.add(Term::from(x.clone())); let t = substitution @@ -382,7 +403,7 @@ fn generic_skolemization_rule( } pool.add(Term::Choice(x.clone(), inner)) }; - if !are_alpha_equivalent(t, &expected, deep_eq_time) { + if !alpha_equiv(t, &expected, polyeq_time) { return Err(EqualityError::ExpectedEqual(t.clone(), expected).into()); } diff --git a/carcara/src/checker/rules/tautology.rs b/carcara/src/checker/rules/tautology.rs index 508127f5..ecbd17db 100644 --- a/carcara/src/checker/rules/tautology.rs +++ b/carcara/src/checker/rules/tautology.rs @@ -1,5 +1,5 @@ use super::{ - assert_clause_len, assert_deep_eq, assert_eq, assert_num_premises, get_premise_term, + assert_clause_len, assert_eq, assert_num_premises, assert_polyeq, get_premise_term, CheckerError, RuleArgs, RuleResult, }; use crate::{ast::*, checker::rules::assert_operation_len}; @@ -258,7 +258,7 @@ pub fn not_ite2(RuleArgs { conclusion, premises, .. }: RuleArgs) -> RuleResult { assert_eq(phi_2, conclusion[1].remove_negation_err()?) } -pub fn ite_intro(RuleArgs { conclusion, deep_eq_time, .. }: RuleArgs) -> RuleResult { +pub fn ite_intro(RuleArgs { conclusion, polyeq_time, .. }: RuleArgs) -> RuleResult { assert_clause_len(conclusion, 1)?; let (root_term, right_side) = match_term_err!((= t u) = &conclusion[0])?; @@ -278,13 +278,13 @@ pub fn ite_intro(RuleArgs { conclusion, deep_eq_time, .. }: RuleArgs) -> RuleRes // ``` // For cases like this, we first check if `t` equals the right side term modulo reordering of // equalities. If not, we unwrap the conjunction and continue checking the rule normally. - if deep_eq(root_term, right_side, deep_eq_time) { + if polyeq(root_term, right_side, polyeq_time) { return Ok(()); } let us = match_term_err!((and ...) = right_side)?; // `us` must be a conjunction where the first term is the root term - assert_deep_eq(&us[0], root_term, deep_eq_time)?; + assert_polyeq(&us[0], root_term, polyeq_time)?; // The remaining terms in `us` should be of the correct form for u_i in &us[1..] { @@ -292,11 +292,11 @@ pub fn ite_intro(RuleArgs { conclusion, deep_eq_time, .. 
}: RuleArgs) -> RuleRes let mut is_valid = |r_1, s_1, r_2, s_2| { // s_1 == s_2 == (ite cond r_1 r_2) - if deep_eq(s_1, s_2, deep_eq_time) { + if polyeq(s_1, s_2, polyeq_time) { if let Some((a, b, c)) = match_term!((ite a b c) = s_1) { - return deep_eq(a, cond, deep_eq_time) - && deep_eq(b, r_1, deep_eq_time) - && deep_eq(c, r_2, deep_eq_time); + return polyeq(a, cond, polyeq_time) + && polyeq(b, r_1, polyeq_time) + && polyeq(c, r_2, polyeq_time); } } false diff --git a/carcara/src/checker/rules/transitivity.rs b/carcara/src/checker/rules/transitivity.rs index 10b0ab5e..17dc4d9b 100644 --- a/carcara/src/checker/rules/transitivity.rs +++ b/carcara/src/checker/rules/transitivity.rs @@ -180,7 +180,7 @@ pub fn elaborate_eq_transitive( if !not_needed.is_empty() { let mut clause = latest_clause; - clause.extend(not_needed.into_iter()); + clause.extend(not_needed); let or_intro_step = ProofStep { id: elaborator.get_new_id(&command_id), clause, @@ -204,7 +204,7 @@ pub fn elaborate_eq_transitive( } fn flip_eq_transitive_premises( - pool: &mut TermPool, + pool: &mut dyn TermPool, elaborator: &mut Elaborator, new_eq_transitive_step: (usize, usize), new_clause: &[Rc], diff --git a/carcara/src/checker/elaboration/accumulator.rs b/carcara/src/elaborator/accumulator.rs similarity index 98% rename from carcara/src/checker/elaboration/accumulator.rs rename to carcara/src/elaborator/accumulator.rs index 4525a8d3..1c145ba4 100644 --- a/carcara/src/checker/elaboration/accumulator.rs +++ b/carcara/src/elaborator/accumulator.rs @@ -66,6 +66,7 @@ impl Accumulator { commands, assignment_args, variable_args, + context_id: 0, }) } diff --git a/carcara/src/checker/elaboration/diff.rs b/carcara/src/elaborator/diff.rs similarity index 93% rename from carcara/src/checker/elaboration/diff.rs rename to carcara/src/elaborator/diff.rs index 65f25c60..26f8d84d 100644 --- a/carcara/src/checker/elaboration/diff.rs +++ b/carcara/src/elaborator/diff.rs @@ -57,10 +57,7 @@ pub fn apply_diff(root: ProofDiff, proof: Vec) -> Vec { + (_, CommandDiff::Step(mut elaboration)) => { f.result.commands.append(&mut elaboration); } (_, CommandDiff::Delete) => (), diff --git a/carcara/src/checker/elaboration/mod.rs b/carcara/src/elaborator/mod.rs similarity index 94% rename from carcara/src/checker/elaboration/mod.rs rename to carcara/src/elaborator/mod.rs index fe08bc6f..291b19aa 100644 --- a/carcara/src/checker/elaboration/mod.rs +++ b/carcara/src/elaborator/mod.rs @@ -1,13 +1,14 @@ mod accumulator; -mod deep_eq; mod diff; +mod polyeq; mod pruning; -use crate::{ast::*, utils::SymbolTable}; +pub use diff::{apply_diff, CommandDiff, ProofDiff}; +pub use pruning::{prune_proof, slice_proof}; + +use crate::{ast::*, utils::HashMapStack}; use accumulator::Accumulator; -use deep_eq::DeepEqElaborator; -use diff::{apply_diff, CommandDiff, ProofDiff}; -use pruning::prune_proof; +use polyeq::PolyeqElaborator; #[derive(Debug, Default)] struct Frame { @@ -33,7 +34,7 @@ impl Frame { #[derive(Debug)] pub struct Elaborator { stack: Vec, - seen_clauses: SymbolTable>, usize>, + seen_clauses: HashMapStack>, usize>, accumulator: Accumulator, } @@ -48,7 +49,7 @@ impl Elaborator { Self { stack: vec![Frame::default()], accumulator: Accumulator::new(), - seen_clauses: SymbolTable::new(), + seen_clauses: HashMapStack::new(), } } @@ -199,7 +200,7 @@ impl Elaborator { /// index must already be mapped to the new index space. 
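`HashMapStack` (the renamed `SymbolTable`) is a stack of scopes searched innermost-first, where popping a scope drops every binding made inside it. A minimal sketch of such a structure; the real type's API may differ in names and details:

use std::collections::HashMap;
use std::hash::Hash;

struct HashMapStack<K, V> {
    scopes: Vec<HashMap<K, V>>,
}

impl<K: Eq + Hash, V> HashMapStack<K, V> {
    fn new() -> Self {
        Self { scopes: vec![HashMap::new()] }
    }
    fn push_scope(&mut self) {
        self.scopes.push(HashMap::new());
    }
    fn pop_scope(&mut self) {
        assert!(self.scopes.len() > 1, "cannot pop the global scope");
        self.scopes.pop();
    }
    fn insert(&mut self, key: K, value: V) {
        self.scopes.last_mut().unwrap().insert(key, value);
    }
    fn get(&self, key: &K) -> Option<&V> {
        // Search from the innermost scope outward
        self.scopes.iter().rev().find_map(|scope| scope.get(key))
    }
}

fn main() {
    let mut table = HashMapStack::new();
    table.insert("x", 0);
    table.push_scope();
    table.insert("x", 1); // shadows the outer binding
    assert_eq!(table.get(&"x"), Some(&1));
    table.pop_scope();
    assert_eq!(table.get(&"x"), Some(&0));
}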
pub fn add_symm_step( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, original_premise: (usize, usize), original_equality: (Rc, Rc), id: String, @@ -220,7 +221,7 @@ impl Elaborator { /// Adds a `refl` step that asserts that the two given terms are equal. pub fn add_refl_step( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, a: Rc, b: Rc, id: String, @@ -236,20 +237,20 @@ impl Elaborator { self.add_new_step(step) } - pub fn elaborate_deep_eq( + pub fn elaborate_polyeq( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, root_id: &str, a: Rc, b: Rc, is_alpha_equivalence: bool, ) -> (usize, usize) { - DeepEqElaborator::new(self, root_id, is_alpha_equivalence).elaborate(pool, a, b) + PolyeqElaborator::new(self, root_id, is_alpha_equivalence).elaborate(pool, a, b) } pub fn elaborate_assume( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, premise: Rc, term: Rc, id: &str, @@ -261,7 +262,7 @@ impl Elaborator { }, false, ); - let equality_step = self.elaborate_deep_eq(pool, id, premise.clone(), term.clone(), false); + let equality_step = self.elaborate_polyeq(pool, id, premise.clone(), term.clone(), false); let equiv1_step = { let new_id = self.get_new_id(id); let clause = vec![build_term!(pool, (not {premise.clone()})), term.clone()]; diff --git a/carcara/src/checker/elaboration/deep_eq.rs b/carcara/src/elaborator/polyeq.rs similarity index 82% rename from carcara/src/checker/elaboration/deep_eq.rs rename to carcara/src/elaborator/polyeq.rs index 3b3809e1..b4b1c483 100644 --- a/carcara/src/checker/elaboration/deep_eq.rs +++ b/carcara/src/elaborator/polyeq.rs @@ -1,32 +1,36 @@ use super::*; use crate::{ ast::*, - checker::context::ContextStack, - utils::{DedupIterator, SymbolTable}, + utils::{DedupIterator, HashMapStack}, }; -pub struct DeepEqElaborator<'a> { +pub struct PolyeqElaborator<'a> { inner: &'a mut Elaborator, root_id: &'a str, - cache: SymbolTable<(Rc, Rc), (usize, usize)>, - checker: DeepEqualityChecker, + cache: HashMapStack<(Rc, Rc), (usize, usize)>, + checker: PolyeqComparator, context: Option, } -impl<'a> DeepEqElaborator<'a> { +impl<'a> PolyeqElaborator<'a> { pub fn new(inner: &'a mut Elaborator, root_id: &'a str, is_alpha_equivalence: bool) -> Self { Self { inner, root_id, - cache: SymbolTable::new(), - checker: DeepEqualityChecker::new(true, is_alpha_equivalence), + cache: HashMapStack::new(), + checker: PolyeqComparator::new(true, is_alpha_equivalence), context: is_alpha_equivalence.then(ContextStack::new), } } /// Takes two terms that are equal modulo reordering of equalities, and returns a premise that /// proves their equality. 
- pub fn elaborate(&mut self, pool: &mut TermPool, a: Rc, b: Rc) -> (usize, usize) { + pub fn elaborate( + &mut self, + pool: &mut dyn TermPool, + a: Rc, + b: Rc, + ) -> (usize, usize) { // TODO: Make this method return an error instead of panicking if the terms aren't equal let key = (a, b); @@ -40,7 +44,12 @@ impl<'a> DeepEqElaborator<'a> { result } - fn elaborate_impl(&mut self, pool: &mut TermPool, a: Rc, b: Rc) -> (usize, usize) { + fn elaborate_impl( + &mut self, + pool: &mut dyn TermPool, + a: Rc, + b: Rc, + ) -> (usize, usize) { if self.directly_eq(pool, &a, &b) { let id = self.inner.get_new_id(self.root_id); return self.inner.add_refl_step(pool, a, b, id); @@ -48,7 +57,7 @@ impl<'a> DeepEqElaborator<'a> { if let Some((a_left, a_right)) = match_term!((= x y) = a) { if let Some((b_left, b_right)) = match_term!((= x y) = b) { - if self.deep_eq(pool, a_left, b_right) && self.deep_eq(pool, a_right, b_left) { + if self.polyeq(pool, a_left, b_right) && self.polyeq(pool, a_right, b_left) { let [a_left, a_right, b_left, b_right] = [a_left, a_right, b_left, b_right].map(Clone::clone); return self.flip_equality(pool, (a, a_left, a_right), (b, b_left, b_right)); @@ -83,7 +92,7 @@ impl<'a> DeepEqElaborator<'a> { }) .collect(); - (a_bindings.as_slice().to_vec(), assignment_args) + (a_bindings.to_vec(), assignment_args) } Some(c) => { assert!(a_bindings @@ -103,7 +112,9 @@ impl<'a> DeepEqElaborator<'a> { .map(|((a_var, _), b)| (a_var.clone(), pool.add(b.clone().into()))) .collect(); - c.push(pool, &assigment_args, &variable_args).unwrap(); + let new_context_id = c.force_new_context(); + c.push(pool, &assigment_args, &variable_args, new_context_id) + .unwrap(); (variable_args, assigment_args) } }; @@ -133,10 +144,7 @@ impl<'a> DeepEqElaborator<'a> { let variable_args: Vec<_> = a_bindings .iter() - .map(|(name, value)| { - let sort = Term::Sort(pool.sort(value).clone()); - (name.clone(), pool.add(sort)) - }) + .map(|(name, value)| (name.clone(), pool.sort(value))) .collect(); self.open_subproof(); @@ -173,14 +181,14 @@ impl<'a> DeepEqElaborator<'a> { } // Since `choice` and `lambda` terms are not in the SMT-LIB standard, they cannot appear - // in the premises of a proof, so we would never need to elaborate deep equalities that + // in the premises of a proof, so we would never need to elaborate polyequalities that // use these terms. (Term::Choice(_, _), Term::Choice(_, _)) => { - log::error!("Trying to elaborate deep equality between `choice` terms"); + log::error!("Trying to elaborate polyequality between `choice` terms"); panic!() } (Term::Lambda(_, _), Term::Lambda(_, _)) => { - log::error!("Trying to elaborate deep equality between `lambda` terms"); + log::error!("Trying to elaborate polyequality between `lambda` terms"); panic!() } _ => panic!("terms not equal!"), @@ -188,7 +196,7 @@ impl<'a> DeepEqElaborator<'a> { } /// Returns `true` if the terms are directly equal, modulo application of the current context. - fn directly_eq(&mut self, pool: &mut TermPool, a: &Rc, b: &Rc) -> bool { + fn directly_eq(&mut self, pool: &mut dyn TermPool, a: &Rc, b: &Rc) -> bool { match &mut self.context { Some(c) => c.apply(pool, a) == *b, None => a == b, @@ -197,16 +205,16 @@ impl<'a> DeepEqElaborator<'a> { /// Returns `true` if the terms are equal modulo reordering of inequalities, and modulo /// application of the current context. 
- fn deep_eq(&mut self, pool: &mut TermPool, a: &Rc, b: &Rc) -> bool { + fn polyeq(&mut self, pool: &mut dyn TermPool, a: &Rc, b: &Rc) -> bool { match &mut self.context { - Some(c) => DeepEq::eq(&mut self.checker, &c.apply(pool, a), b), - None => DeepEq::eq(&mut self.checker, a, b), + Some(c) => Polyeq::eq(&mut self.checker, &c.apply(pool, a), b), + None => Polyeq::eq(&mut self.checker, a, b), } } fn build_cong( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, (a, b): (&Rc, &Rc), (a_args, b_args): (&[Rc], &[Rc]), ) -> (usize, usize) { @@ -236,7 +244,7 @@ impl<'a> DeepEqElaborator<'a> { fn flip_equality( &mut self, - pool: &mut TermPool, + pool: &mut dyn TermPool, (a, a_left, a_right): (Rc, Rc, Rc), (b, b_left, b_right): (Rc, Rc, Rc), ) -> (usize, usize) { @@ -253,12 +261,13 @@ impl<'a> DeepEqElaborator<'a> { // reordering of equalities, or if they are equal modulo the application of the current // context (in the case of alpha equivalence). // - // In this case, we need to elaborate the deep equality between x and x' (or y and y'), and - // from that, prove that `(= (= x y) (= y' x))`. We do that by first proving that `(= x x')` - // (1) and `(= y y')` (2). Then, we introduce a `cong` step that uses (1) and (2) to show - // that `(= (= x y) (= x' y'))` (3). After that, we add an `equiv_simplify` step that - // derives `(= (= x' y') (= y' x'))` (4). Finally, we introduce a `trans` step with premises - // (3) and (4) that proves `(= (= x y) (= y' x'))`. The general format looks like this: + // In this case, we need to elaborate the polyequality between x and x' (or y and y'), and + // from that, prove that `(= (= x y) (= y' x'))`. We do that by first proving that + // `(= x x')` (1) and `(= y y')` (2). Then, we introduce a `cong` step that uses (1) and (2) + // to show that `(= (= x y) (= x' y'))` (3). After that, we add an `equiv_simplify` step + // that derives `(= (= x' y') (= y' x'))` (4). Finally, we introduce a `trans` step with + // premises (3) and (4) that proves `(= (= x y) (= y' x'))`. The general format looks like + // this: // // ... // (step t1 (cl (= x x')) ...) @@ -296,11 +305,20 @@ impl<'a> DeepEqElaborator<'a> { (&[a_left, a_right], &[b_right, b_left]), ); + // It might be the case that `x'` is syntactically equal to `y'`, which would mean that we + // are adding an `equiv_simplify` step to prove a reflexivity step. This is not valid + // according to the `equiv_simplify` specification, so we must change the rule to `refl` in + // this case. + let rule = if b == flipped_b { + "refl".to_owned() + } else { + "equiv_simplify".to_owned() + }; let id = self.inner.get_new_id(self.root_id); let equiv_step = self.inner.add_new_step(ProofStep { id, clause: vec![build_term!(pool, (= {flipped_b} {b.clone()}))], - rule: "equiv_simplify".to_owned(), + rule, premises: Vec::new(), args: Vec::new(), discharge: Vec::new(), @@ -339,7 +357,11 @@ impl<'a> DeepEqElaborator<'a> { /// Creates the subproof for a `bind` or `bind_let` step, used to derive the equality of /// quantifier or `let` terms. This assumes the accumulator subproof has already been opened. 
- fn create_bind_subproof(&mut self, pool: &mut TermPool, inner_equality: (Rc, Rc)) { + fn create_bind_subproof( + &mut self, + pool: &mut dyn TermPool, + inner_equality: (Rc, Rc), + ) { let (a, b) = inner_equality; let inner_eq = self.elaborate(pool, a.clone(), b.clone()); diff --git a/carcara/src/checker/elaboration/pruning.rs b/carcara/src/elaborator/pruning.rs similarity index 52% rename from carcara/src/checker/elaboration/pruning.rs rename to carcara/src/elaborator/pruning.rs index a55ae7a6..5c5de3c2 100644 --- a/carcara/src/checker/elaboration/pruning.rs +++ b/carcara/src/elaborator/pruning.rs @@ -10,77 +10,92 @@ struct Frame<'a> { /// The index of the subproof that this frame represents, in the outer subproof index_of_subproof: usize, - visited: Vec, + + /// For each command, the distance between it and the source. + distance_to_source: Vec, + + /// The queue of commands to visit, represented as a tuple of (command index, distance to + /// source) + queue: VecDeque<(usize, usize)>, } pub fn prune_proof(proof: &[ProofCommand]) -> ProofDiff { - assert!(!proof.is_empty(), "cannot prune an empty proof"); - let end_step = proof .iter() .position(|c| c.clause().is_empty()) .expect("proof does not reach empty clause"); - let root = Frame { + slice_proof(proof, end_step, None) +} + +pub fn slice_proof( + proof: &[ProofCommand], + source: usize, + max_distance: Option, +) -> ProofDiff { + assert!(proof.len() > source, "invalid slice index"); + + let mut stack = vec![Frame { commands: proof, subproof_diffs: vec![None; proof.len()], - visited: vec![false; proof.len()], + distance_to_source: vec![usize::MAX; proof.len()], index_of_subproof: 0, // For the root proof, this value is irrelevant - }; - let mut stack = vec![root]; - let mut to_visit = vec![VecDeque::from([end_step])]; + queue: VecDeque::from([(source, 0usize)]), + }]; loop { 'inner: loop { let frame = stack.last_mut().unwrap(); - let Some(current) = to_visit.last_mut().unwrap().pop_front() else { + let Some((current, current_dist)) = frame.queue.pop_front() else { break 'inner; }; - if frame.visited[current] { + if frame.distance_to_source[current] < usize::MAX { + continue; + } + frame.distance_to_source[current] = + std::cmp::min(frame.distance_to_source[current], current_dist); + + if max_distance.is_some_and(|max| current_dist > max) { continue; } - frame.visited[current] = true; match &frame.commands[current] { ProofCommand::Assume { .. } => (), ProofCommand::Step(s) => { for &(depth, i) in &s.premises { - to_visit[depth].push_back(i); + stack[depth].queue.push_back((i, current_dist + 1)); } } ProofCommand::Subproof(s) => { let n = s.commands.len(); - let mut visited = vec![false; n]; let mut new_queue = VecDeque::new(); - new_queue.push_back(n - 1); - - // Since the second to last command in a subproof may be implicitly referenced - // by the last command, we have to add it to the `to_visit` queue if it exists - if n >= 2 { - new_queue.push_back(n - 2); - } + new_queue.push_back((n - 1, current_dist)); - // Since `assume` commands in the subproof cannot be removed we need to always - // visit them. 
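The new `distance_to_source` and `queue` fields below turn pruning into a breadth-first traversal of the proof DAG along premise edges, so every reachable step gets its shortest distance from the source step. Boiled down to indices over a premise adjacency list:

use std::collections::VecDeque;

fn distances(premises: &[Vec<usize>], source: usize) -> Vec<usize> {
    let mut dist = vec![usize::MAX; premises.len()];
    let mut queue = VecDeque::from([(source, 0usize)]);
    while let Some((current, d)) = queue.pop_front() {
        if dist[current] != usize::MAX {
            continue; // already visited at an equal or smaller distance (BFS is FIFO)
        }
        dist[current] = d;
        for &p in &premises[current] {
            queue.push_back((p, d + 1));
        }
    }
    // usize::MAX marks steps the source never reaches; they can be deleted
    dist
}

fn main() {
    // Step 3 depends on 1 and 2; step 2 depends on 0
    let premises = vec![vec![], vec![], vec![0], vec![1, 2]];
    assert_eq!(distances(&premises, 3), vec![2, 1, 1, 0]);
}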
As they don't have any premises, we can just mark them as visited - // now + // Since `assume` commands in a subproof are implicitly referenced by the last + // step in the subproof, we must add them to the queue now for (i, command) in s.commands.iter().enumerate() { if command.is_assume() { - visited[i] = true; + new_queue.push_back((i, current_dist + 1)); } } + // The second to last command in a subproof is also implicitly referenced by the + // last step, so we also add it to the queue + if n >= 2 { + new_queue.push_back((n - 2, current_dist + 1)); + } + let frame = Frame { commands: &s.commands, subproof_diffs: vec![None; n], - visited, + distance_to_source: vec![usize::MAX; n], index_of_subproof: current, + queue: new_queue, }; stack.push(frame); - to_visit.push(new_queue); } } } - to_visit.pop(); let mut frame = stack.pop().unwrap(); let mut result_diff = Vec::new(); @@ -90,9 +105,19 @@ pub fn prune_proof(proof: &[ProofCommand]) -> ProofDiff { for i in 0..frame.commands.len() { new_indices.push((depth, i - num_pruned)); - if !frame.visited[i] { + if frame.distance_to_source[i] == usize::MAX { result_diff.push((i, CommandDiff::Delete)); num_pruned += 1; + } else if max_distance.is_some_and(|max| frame.distance_to_source[i] == max + 1) { + let new_command = ProofCommand::Step(ProofStep { + id: frame.commands[i].id().to_owned(), + clause: frame.commands[i].clause().to_vec(), + rule: "hole".to_owned(), + premises: Vec::new(), + args: Vec::new(), + discharge: Vec::new(), + }); + result_diff.push((i, CommandDiff::Step(vec![new_command]))); } else if let Some(diff) = frame.subproof_diffs[i].take() { result_diff.push((i, CommandDiff::Subproof(diff))); } diff --git a/carcara/src/lib.rs b/carcara/src/lib.rs index 2e8ceddb..ae70f6b0 100644 --- a/carcara/src/lib.rs +++ b/carcara/src/lib.rs @@ -26,6 +26,7 @@ #![warn(clippy::multiple_crate_versions)] #![warn(clippy::redundant_closure_for_method_calls)] #![warn(clippy::redundant_pub_crate)] +#![warn(clippy::redundant_type_annotations)] #![warn(clippy::semicolon_if_nothing_returned)] #![warn(clippy::str_to_string)] #![warn(clippy::string_to_string)] @@ -38,13 +39,15 @@ pub mod ast; pub mod benchmarking; pub mod checker; +pub mod elaborator; pub mod parser; mod utils; -use checker::error::CheckerError; -use parser::ParserError; -use parser::Position; +use crate::benchmarking::{CollectResults, OnlineBenchmarkResults, RunMeasurement}; +use checker::{error::CheckerError, CheckerStatistics}; +use parser::{ParserError, Position}; use std::io; +use std::time::{Duration, Instant}; use thiserror::Error; pub type CarcaraResult = Result; @@ -70,11 +73,11 @@ pub struct CarcaraOptions { /// to a function that expects a `Real` will still be an error. pub allow_int_real_subtyping: bool, - /// Enable checking/elaboration of `lia_generic` steps using cvc5. When checking a proof, this - /// will call cvc5 to solve the linear integer arithmetic problem, check the proof, and discard - /// it. When elaborating, the proof will instead be inserted in the place of the `lia_generic` - /// step. - pub lia_via_cvc5: bool, + /// If `Some`, enables the checking/elaboration of `lia_generic` steps using an external solver. + /// When checking a proof, this means calling the solver to solve the linear integer arithmetic + /// problem, checking the proof, and discarding it. When elaborating, the proof will instead be + /// inserted in the place of the `lia_generic` step. See [`LiaGenericOptions`] for more details. 
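Putting the new options together, a hypothetical driver might look like the following. The solver path, its arguments, and the SMT-LIB/Alethe snippets are placeholders; `CarcaraOptions::new()` is assumed to give defaults per the `impl` block below, and the `bool` returned by `check` is taken to indicate whether the checked proof contains holes:

use carcara::{check, CarcaraOptions, LiaGenericOptions};
use std::io::Cursor;

fn main() {
    let mut options = CarcaraOptions::new();
    // Check `lia_generic` steps by calling an external solver; the path and
    // flag here are placeholders, not a recommendation
    options.lia_options = Some(LiaGenericOptions {
        solver: "path/to/solver".into(),
        arguments: vec!["--some-flag".into()],
    });
    options.stats = true; // log checking/elaboration statistics

    // Placeholder inputs; a real caller would read the problem and proof files
    let problem = Cursor::new("(declare-const p Bool) (assert p)");
    let proof = Cursor::new("(assume h1 p)");
    match check(problem, proof, options) {
        Ok(is_holey) => println!("valid (contains holes: {is_holey})"),
        Err(e) => println!("invalid or errored: {e}"),
    }
}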
+ pub lia_options: Option, /// Enables "strict" checking of some rules. /// @@ -87,9 +90,25 @@ pub struct CarcaraOptions { /// benefit). pub strict: bool, - /// If `true`, Carcara will skip any rules that it does not recognize, and will consider them as + /// If `true`, Carcara will skip any steps with rules that it does not recognize, and will consider them as /// holes. Normally, using an unknown rule is considered an error. - pub skip_unknown_rules: bool, + pub ignore_unknown_rules: bool, + + /// If `true`, Carcará will log the check and elaboration statistics of any + /// `check` or `check_and_elaborate` run. If `false` no statistics are logged. + pub stats: bool, +} + +/// The options that control how `lia_generic` steps are checked/elaborated using an external +/// solver. +#[derive(Debug, Clone)] +pub struct LiaGenericOptions { + /// The external solver path. The solver should be a binary that can read SMT-LIB from stdin and + /// output an Alethe proof to stdout. + pub solver: Box, + + /// The arguments to pass to the solver. + pub arguments: Vec>, } impl CarcaraOptions { @@ -130,19 +149,134 @@ pub enum Error { } pub fn check(problem: T, proof: T, options: CarcaraOptions) -> Result { - let (prelude, proof, mut pool) = parser::parse_instance( - problem, - proof, - options.apply_function_defs, - options.expand_lets, - options.allow_int_real_subtyping, - )?; + let mut run_measures: RunMeasurement = RunMeasurement::default(); + + // Parsing + let total = Instant::now(); + let config = parser::Config { + apply_function_defs: options.apply_function_defs, + expand_lets: options.expand_lets, + allow_int_real_subtyping: options.allow_int_real_subtyping, + }; + let (prelude, proof, mut pool) = parser::parse_instance(problem, proof, config)?; + run_measures.parsing = total.elapsed(); + + let config = checker::Config::new() + .strict(options.strict) + .ignore_unknown_rules(options.ignore_unknown_rules) + .lia_options(options.lia_options); + + // Checking + let checking = Instant::now(); + let mut checker = checker::ProofChecker::new(&mut pool, config, &prelude); + if options.stats { + let mut checker_stats = CheckerStatistics { + file_name: "this", + elaboration_time: Duration::ZERO, + polyeq_time: Duration::ZERO, + assume_time: Duration::ZERO, + assume_core_time: Duration::ZERO, + results: OnlineBenchmarkResults::new(), + }; + let res = checker.check_with_stats(&proof, &mut checker_stats); + + run_measures.checking = checking.elapsed(); + run_measures.total = total.elapsed(); + + checker_stats.results.add_run_measurement( + &("this".to_owned(), 0), + RunMeasurement { + parsing: run_measures.parsing, + checking: run_measures.checking, + elaboration: checker_stats.elaboration_time, + scheduling: run_measures.scheduling, + total: run_measures.total, + polyeq: checker_stats.polyeq_time, + assume: checker_stats.assume_time, + assume_core: checker_stats.assume_core_time, + }, + ); + // Print the statistics + checker_stats.results.print(false); + + res + } else { + checker.check(&proof) + } +} + +pub fn check_parallel( + problem: T, + proof: T, + options: CarcaraOptions, + num_threads: usize, + stack_size: usize, +) -> Result { + use crate::checker::Scheduler; + use std::sync::Arc; + let mut run_measures: RunMeasurement = RunMeasurement::default(); + + // Parsing + let total = Instant::now(); + let config = parser::Config { + apply_function_defs: options.apply_function_defs, + expand_lets: options.expand_lets, + allow_int_real_subtyping: options.allow_int_real_subtyping, + }; + let (prelude, proof, 
pool) = parser::parse_instance(problem, proof, config)?; + run_measures.parsing = total.elapsed(); let config = checker::Config::new() .strict(options.strict) - .skip_unknown_rules(options.skip_unknown_rules) - .lia_via_cvc5(options.lia_via_cvc5); - checker::ProofChecker::new(&mut pool, config, prelude).check(&proof) + .ignore_unknown_rules(options.ignore_unknown_rules) + .lia_options(options.lia_options); + + // Checking + let checking = Instant::now(); + let (scheduler, schedule_context_usage) = Scheduler::new(num_threads, &proof); + run_measures.scheduling = checking.elapsed(); + let mut checker = checker::ParallelProofChecker::new( + Arc::new(pool), + config, + &prelude, + &schedule_context_usage, + stack_size, + ); + + if options.stats { + let mut checker_stats = CheckerStatistics { + file_name: "this", + elaboration_time: Duration::ZERO, + polyeq_time: Duration::ZERO, + assume_time: Duration::ZERO, + assume_core_time: Duration::ZERO, + results: OnlineBenchmarkResults::new(), + }; + let res = checker.check_with_stats(&proof, &scheduler, &mut checker_stats); + + run_measures.checking = checking.elapsed(); + run_measures.total = total.elapsed(); + + checker_stats.results.add_run_measurement( + &("this".to_owned(), 0), + RunMeasurement { + parsing: run_measures.parsing, + checking: run_measures.checking, + elaboration: checker_stats.elaboration_time, + scheduling: run_measures.scheduling, + total: run_measures.total, + polyeq: checker_stats.polyeq_time, + assume: checker_stats.assume_time, + assume_core: checker_stats.assume_core_time, + }, + ); + // Print the statistics + checker_stats.results.print(false); + + res + } else { + checker.check(&proof, &scheduler) + } } pub fn check_and_elaborate( @@ -150,17 +284,58 @@ pub fn check_and_elaborate( proof: T, options: CarcaraOptions, ) -> Result<(bool, ast::Proof), Error> { - let (prelude, proof, mut pool) = parser::parse_instance( - problem, - proof, - options.apply_function_defs, - options.expand_lets, - options.allow_int_real_subtyping, - )?; + let mut run_measures: RunMeasurement = RunMeasurement::default(); + + // Parsing + let total = Instant::now(); + let config = parser::Config { + apply_function_defs: options.apply_function_defs, + expand_lets: options.expand_lets, + allow_int_real_subtyping: options.allow_int_real_subtyping, + }; + let (prelude, proof, mut pool) = parser::parse_instance(problem, proof, config)?; + run_measures.parsing = total.elapsed(); let config = checker::Config::new() .strict(options.strict) - .skip_unknown_rules(options.skip_unknown_rules) - .lia_via_cvc5(options.lia_via_cvc5); - checker::ProofChecker::new(&mut pool, config, prelude).check_and_elaborate(proof) + .ignore_unknown_rules(options.ignore_unknown_rules) + .lia_options(options.lia_options); + + // Checking + let checking = Instant::now(); + let mut checker = checker::ProofChecker::new(&mut pool, config, &prelude); + if options.stats { + let mut checker_stats = CheckerStatistics { + file_name: "this", + elaboration_time: Duration::ZERO, + polyeq_time: Duration::ZERO, + assume_time: Duration::ZERO, + assume_core_time: Duration::ZERO, + results: OnlineBenchmarkResults::new(), + }; + + let res = checker.check_and_elaborate_with_stats(proof, &mut checker_stats); + run_measures.checking = checking.elapsed(); + run_measures.total = total.elapsed(); + + checker_stats.results.add_run_measurement( + &("this".to_owned(), 0), + RunMeasurement { + parsing: run_measures.parsing, + checking: run_measures.checking, + elaboration: checker_stats.elaboration_time, + 
scheduling: run_measures.scheduling, + total: run_measures.total, + polyeq: checker_stats.polyeq_time, + assume: checker_stats.assume_time, + assume_core: checker_stats.assume_core_time, + }, + ); + // Print the statistics + checker_stats.results.print(false); + + res + } else { + checker.check_and_elaborate(proof) + } } diff --git a/carcara/src/parser/error.rs b/carcara/src/parser/error.rs index 0b38465f..f2ea52ec 100644 --- a/carcara/src/parser/error.rs +++ b/carcara/src/parser/error.rs @@ -1,7 +1,7 @@ //! The types for parser errors. use crate::{ - ast::{Identifier, Sort}, + ast::{Ident, Sort}, parser::Token, utils::Range, }; @@ -55,7 +55,7 @@ pub enum ParserError { /// The parser encountered an identifier that was not defined. #[error("identifier '{0}' is not defined")] - UndefinedIden(Identifier), + UndefinedIden(Ident), /// The parser encountered a sort that was not defined. #[error("sort '{0}' is not defined")] diff --git a/carcara/src/parser/mod.rs b/carcara/src/parser/mod.rs index b48ed566..b628944f 100644 --- a/carcara/src/parser/mod.rs +++ b/carcara/src/parser/mod.rs @@ -9,14 +9,27 @@ pub use lexer::{Lexer, Position, Reserved, Token}; use crate::{ ast::*, - utils::{HashCache, SymbolTable}, + utils::{HashCache, HashMapStack}, CarcaraResult, Error, }; -use ahash::{AHashMap, AHashSet}; use error::assert_num_args; +use indexmap::{IndexMap, IndexSet}; use rug::Integer; use std::{io::BufRead, str::FromStr}; +#[derive(Debug, Default, Clone, Copy)] +pub struct Config { + pub apply_function_defs: bool, + pub expand_lets: bool, + pub allow_int_real_subtyping: bool, +} + +impl Config { + pub fn new() -> Self { + Self::default() + } +} + /// Parses an SMT problem instance (in the SMT-LIB format) and its associated proof (in the Alethe /// format). /// @@ -25,18 +38,10 @@ use std::{io::BufRead, str::FromStr}; pub fn parse_instance( problem: T, proof: T, - apply_function_defs: bool, - expand_lets: bool, - allow_int_real_subtyping: bool, -) -> CarcaraResult<(ProblemPrelude, Proof, TermPool)> { - let mut pool = TermPool::new(); - let mut parser = Parser::new( - &mut pool, - problem, - apply_function_defs, - expand_lets, - allow_int_real_subtyping, - )?; + config: Config, +) -> CarcaraResult<(ProblemPrelude, Proof, PrimitivePool)> { + let mut pool = PrimitivePool::new(); + let mut parser = Parser::new(&mut pool, config, problem)?; let (prelude, premises) = parser.parse_problem()?; parser.reset(proof)?; let commands = parser.parse_proof()?; @@ -72,56 +77,46 @@ enum AnchorArg { /// pool used by the parser. #[derive(Default)] struct ParserState { - symbol_table: SymbolTable, Rc>, - function_defs: AHashMap, - sort_declarations: AHashMap, - step_ids: SymbolTable, usize>, + symbol_table: HashMapStack, Rc>, + function_defs: IndexMap, + sort_declarations: IndexMap, + step_ids: HashMapStack, usize>, } /// A parser for the Alethe proof format. pub struct Parser<'a, R> { - pool: &'a mut TermPool, + pool: &'a mut PrimitivePool, + config: Config, lexer: Lexer, current_token: Token, current_position: Position, state: ParserState, interpret_integers_as_reals: bool, - apply_function_defs: bool, - expand_lets: bool, - problem: Option<(ProblemPrelude, AHashSet>)>, - allow_int_real_subtyping: bool, + problem: Option<(ProblemPrelude, IndexSet>)>, } impl<'a, R: BufRead> Parser<'a, R> { /// Constructs a new `Parser` from a type that implements `BufRead`. /// /// This operation can fail if there is an IO or lexer error on the first token. 
- pub fn new( - pool: &'a mut TermPool, - input: R, - apply_function_defs: bool, - expand_lets: bool, - allow_int_real_subtyping: bool, - ) -> CarcaraResult<Self> { + pub fn new(pool: &'a mut PrimitivePool, config: Config, input: R) -> CarcaraResult<Self> { let mut state = ParserState::default(); let bool_sort = pool.add(Term::Sort(Sort::Bool)); for iden in ["true", "false"] { - let iden = HashCache::new(Identifier::Simple(iden.to_owned())); + let iden = HashCache::new(Ident::Simple(iden.to_owned())); state.symbol_table.insert(iden, bool_sort.clone()); } let mut lexer = Lexer::new(input)?; let (current_token, current_position) = lexer.next_token()?; Ok(Parser { pool, + config, lexer, current_token, current_position, state, interpret_integers_as_reals: false, - apply_function_defs, - expand_lets, problem: None, - allow_int_real_subtyping, }) } @@ -150,7 +145,7 @@ impl<'a, R: BufRead> Parser<'a, R> { fn insert_sorted_var(&mut self, (symbol, sort): SortedVar) { self.state .symbol_table - .insert(HashCache::new(Identifier::Simple(symbol)), sort); + .insert(HashCache::new(Ident::Simple(symbol)), sort); } /// Shortcut for `self.problem.as_mut().unwrap().0` @@ -159,20 +154,18 @@ impl<'a, R: BufRead> Parser<'a, R> { } /// Shortcut for `self.problem.as_mut().unwrap().1` - fn premises(&mut self) -> &mut AHashSet<Rc<Term>> { + fn premises(&mut self) -> &mut IndexSet<Rc<Term>> { &mut self.problem.as_mut().unwrap().1 } /// Constructs and sort checks a variable term. - fn make_var(&mut self, iden: Identifier) -> Result<Rc<Term>, ParserError> { + fn make_var(&mut self, iden: Ident) -> Result<Rc<Term>, ParserError> { let cached = HashCache::new(iden); let sort = match self.state.symbol_table.get(&cached) { Some(s) => s.clone(), None => return Err(ParserError::UndefinedIden(cached.unwrap())), }; - Ok(self - .pool - .add(Term::Terminal(Terminal::Var(cached.unwrap(), sort)))) + Ok(self.pool.add(Term::Var(cached.unwrap(), sort))) } /// Constructs and sort checks an operation term. @@ -181,29 +174,34 @@ match op { Operator::Not => { assert_num_args(&args, 1)?; - SortError::assert_eq(&Sort::Bool, sorts[0])?; + SortError::assert_eq(&Sort::Bool, sorts[0].as_sort().unwrap())?; } Operator::Implies => { assert_num_args(&args, 2..)?; for s in sorts { - SortError::assert_eq(&Sort::Bool, s)?; + SortError::assert_eq(&Sort::Bool, s.as_sort().unwrap())?; } } Operator::Or | Operator::And | Operator::Xor => { // These operators can be called with only one argument assert_num_args(&args, 1..)?; for s in sorts { - SortError::assert_eq(&Sort::Bool, s)?; + SortError::assert_eq(&Sort::Bool, s.as_sort().unwrap())?; } } Operator::Equals | Operator::Distinct => { assert_num_args(&args, 2..)?; - SortError::assert_all_eq(&sorts)?; + SortError::assert_all_eq( + &sorts + .iter() + .map(|op| op.as_sort().unwrap()) + .collect::<Vec<_>>(), + )?; } Operator::Ite => { assert_num_args(&args, 3)?; - SortError::assert_eq(&Sort::Bool, sorts[0])?; - SortError::assert_eq(sorts[1], sorts[2])?; + SortError::assert_eq(&Sort::Bool, sorts[0].as_sort().unwrap())?; + SortError::assert_eq(sorts[1].as_sort().unwrap(), sorts[2].as_sort().unwrap())?; } Operator::Add | Operator::Sub | Operator::Mult => { // The `-` operator, in particular, can be called with only one argument, in which @@ -216,62 +214,80 @@ impl<'a, R: BufRead> Parser<'a, R> { // All the arguments must be either Int or Real.
Also, if we are not allowing // Int/Real subtyping, all arguments must have the same sort - if self.allow_int_real_subtyping { + if self.config.allow_int_real_subtyping { for s in sorts { - SortError::assert_one_of(&[Sort::Int, Sort::Real], s)?; + SortError::assert_one_of(&[Sort::Int, Sort::Real], s.as_sort().unwrap())?; } } else { - SortError::assert_one_of(&[Sort::Int, Sort::Real], sorts[0])?; - SortError::assert_all_eq(&sorts)?; + SortError::assert_one_of( + &[Sort::Int, Sort::Real], + sorts[0].as_sort().unwrap(), + )?; + SortError::assert_all_eq( + &sorts + .iter() + .map(|op| op.as_sort().unwrap()) + .collect::<Vec<_>>(), + )?; } } Operator::IntDiv => { assert_num_args(&args, 2..)?; - SortError::assert_eq(&Sort::Int, sorts[0])?; - SortError::assert_all_eq(&sorts)?; + SortError::assert_eq(&Sort::Int, sorts[0].as_sort().unwrap())?; + SortError::assert_all_eq( + &sorts + .iter() + .map(|op| op.as_sort().unwrap()) + .collect::<Vec<_>>(), + )?; } Operator::RealDiv => { assert_num_args(&args, 2..)?; // Normally, the `/` operator may only receive Real arguments, but if we are // allowing Int/Real subtyping, it may also receive Ints - if self.allow_int_real_subtyping { + if self.config.allow_int_real_subtyping { for s in sorts { - SortError::assert_one_of(&[Sort::Int, Sort::Real], s)?; + SortError::assert_one_of(&[Sort::Int, Sort::Real], s.as_sort().unwrap())?; } } else { - SortError::assert_eq(&Sort::Real, sorts[0])?; - SortError::assert_all_eq(&sorts)?; + SortError::assert_eq(&Sort::Real, sorts[0].as_sort().unwrap())?; + SortError::assert_all_eq( + &sorts + .iter() + .map(|op| op.as_sort().unwrap()) + .collect::<Vec<_>>(), + )?; } } Operator::Mod => { assert_num_args(&args, 2)?; - SortError::assert_eq(&Sort::Int, sorts[0])?; - SortError::assert_eq(&Sort::Int, sorts[1])?; + SortError::assert_eq(&Sort::Int, sorts[0].as_sort().unwrap())?; + SortError::assert_eq(&Sort::Int, sorts[1].as_sort().unwrap())?; } Operator::Abs => { assert_num_args(&args, 1)?; - SortError::assert_eq(&Sort::Int, sorts[0])?; + SortError::assert_eq(&Sort::Int, sorts[0].as_sort().unwrap())?; } Operator::LessThan | Operator::GreaterThan | Operator::LessEq | Operator::GreaterEq => { assert_num_args(&args, 2..)?; // All the arguments must be either Int or Real sorted, but they don't need to all // have the same sort for s in sorts { - SortError::assert_one_of(&[Sort::Int, Sort::Real], s)?; + SortError::assert_one_of(&[Sort::Int, Sort::Real], s.as_sort().unwrap())?; } } Operator::ToReal => { assert_num_args(&args, 1)?; - SortError::assert_eq(&Sort::Int, sorts[0])?; + SortError::assert_eq(&Sort::Int, sorts[0].as_sort().unwrap())?; } Operator::ToInt | Operator::IsInt => { assert_num_args(&args, 1)?; - SortError::assert_eq(&Sort::Real, sorts[0])?; + SortError::assert_eq(&Sort::Real, sorts[0].as_sort().unwrap())?; } Operator::Select => { assert_num_args(&args, 2)?; - match sorts[0] { + match sorts[0].as_sort().unwrap() { Sort::Array(_, _) => (), got => { // Instead of creating some special case for sort errors with parametric @@ -279,7 +295,7 @@ impl<'a, R: BufRead> Parser<'a, R> { // infer the `X` sort from the second operator argument.
This may be // changed later let got = got.clone(); - let x = sorts[1].clone(); + let x = sorts[1].as_sort().unwrap().clone(); let x = self.pool.add(Term::Sort(x)); let y = self .pool @@ -294,14 +310,15 @@ impl<'a, R: BufRead> Parser<'a, R> { } Operator::Store => { assert_num_args(&args, 3)?; - match sorts[0] { + match sorts[0].as_sort().unwrap() { Sort::Array(x, y) => { - SortError::assert_eq(x.as_sort().unwrap(), sorts[1])?; - SortError::assert_eq(y.as_sort().unwrap(), sorts[2])?; + SortError::assert_eq(x.as_sort().unwrap(), sorts[1].as_sort().unwrap())?; + SortError::assert_eq(y.as_sort().unwrap(), sorts[2].as_sort().unwrap())?; } got => { let got = got.clone(); - let [x, y] = [sorts[0], sorts[1]].map(|s| Term::Sort(s.clone())); + let [x, y] = [&sorts[0], &sorts[1]] + .map(|s| Term::Sort(s.as_sort().unwrap().clone())); return Err(SortError { expected: vec![Sort::Array(self.pool.add(x), self.pool.add(y))], got, @@ -320,8 +337,9 @@ impl<'a, R: BufRead> Parser<'a, R> { function: Rc, args: Vec>, ) -> Result, ParserError> { + let sort = self.pool.sort(&function); let sorts = { - let function_sort = self.pool.sort(&function); + let function_sort = sort.as_sort().unwrap(); if let Sort::Function(sorts) = function_sort { sorts } else { @@ -331,7 +349,10 @@ impl<'a, R: BufRead> Parser<'a, R> { }; assert_num_args(&args, sorts.len() - 1)?; for i in 0..args.len() { - SortError::assert_eq(sorts[i].as_sort().unwrap(), self.pool.sort(&args[i]))?; + SortError::assert_eq( + sorts[i].as_sort().unwrap(), + self.pool.sort(&args[i]).as_sort().unwrap(), + )?; } Ok(self.pool.add(Term::App(function, args))) } @@ -468,8 +489,8 @@ impl<'a, R: BufRead> Parser<'a, R> { /// /// All other commands are ignored. This method returns a hash set containing the premises /// introduced in `assert` commands. 
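Throughout the hunks above, `pool.sort` now returns the sort as a term (an `Rc<Term>` wrapping a `Term::Sort`) rather than a bare `&Sort`, which is why every call site unwraps it with `as_sort().unwrap()`. A minimal sketch of the recurring pattern; the helper function is written only for illustration and is not part of the patch:

// Extract the `Sort` of a term under the new pool API.
fn sort_of(pool: &mut PrimitivePool, term: &Rc<Term>) -> Sort {
    // `as_sort` returns `Option<&Sort>`; it is `Some` for the `Term::Sort`
    // returned by `pool.sort`, so the unwrap mirrors the call sites above.
    pool.sort(term).as_sort().unwrap().clone()
}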
- pub fn parse_problem(&mut self) -> CarcaraResult<(ProblemPrelude, AHashSet<Rc<Term>>)> { - self.problem = Some((ProblemPrelude::default(), AHashSet::new())); + pub fn parse_problem(&mut self) -> CarcaraResult<(ProblemPrelude, IndexSet<Rc<Term>>)> { + self.problem = Some((ProblemPrelude::default(), IndexSet::new())); while self.current_token != Token::Eof { self.expect_token(Token::OpenParen)?; @@ -502,7 +523,7 @@ impl<'a, R: BufRead> Parser<'a, R> { Token::ReservedWord(Reserved::DefineFun) => { let (name, func_def) = self.parse_define_fun()?; - if self.apply_function_defs { + if self.config.apply_function_defs { self.state.function_defs.insert(name, func_def); } else { - // If `self.apply_function_defs` is false, we instead add the function name + // If `self.config.apply_function_defs` is false, we instead add the function name @@ -513,9 +534,7 @@ impl<'a, R: BufRead> Parser<'a, R> { self.pool .add(Term::Lambda(BindingList(func_def.params), func_def.body)) }; - let sort = self - .pool - .add(Term::Sort(self.pool.sort(&lambda_term).clone())); + let sort = self.pool.sort(&lambda_term); let var = (name, sort); self.insert_sorted_var(var.clone()); let var_term = self.pool.add(var.into()); @@ -559,6 +578,8 @@ impl<'a, R: BufRead> Parser<'a, R> { let mut commands_stack = vec![Vec::new()]; let mut end_step_stack = Vec::new(); let mut subproof_args_stack = Vec::new(); + let mut subproof_id_stack = Vec::new(); + let mut last_subproof_id: i64 = -1; let mut finished_assumes = false; @@ -596,6 +617,8 @@ commands_stack.push(Vec::new()); end_step_stack.push(anchor.end_step_id); subproof_args_stack.push((anchor.assignment_args, anchor.variable_args)); + last_subproof_id += 1; + subproof_id_stack.push(last_subproof_id as usize); continue; } _ => return Err(Error::Parser(ParserError::UnexpectedToken(token), position)), @@ -617,6 +640,7 @@ let commands = commands_stack.pop().unwrap(); end_step_stack.pop().unwrap(); let (assignment_args, variable_args) = subproof_args_stack.pop().unwrap(); + let subproof_id = subproof_id_stack.pop().unwrap(); // The subproof must contain at least two commands: the end step and the previous // command it implicitly references @@ -645,6 +669,7 @@ commands, assignment_args, variable_args, + context_id: subproof_id, })); } self.state @@ -669,6 +694,7 @@ fn parse_assume_command(&mut self) -> CarcaraResult<(String, Rc<Term>)> { let id = self.expect_symbol()?; let term = self.parse_term_expecting_sort(&Sort::Bool)?; + self.ignore_remaining_attributes()?; self.expect_token(Token::CloseParen)?; Ok((id, term)) } @@ -709,7 +735,7 @@ Vec::new() }; - // For some rules (notable the `subproof` rule), there is also a `:discharge` attribute that + // For some rules (notably the `subproof` rule), there is also a `:discharge` attribute that // takes a series of command ids, in addition to the regular premises let discharge = if self.current_token == Token::Keyword("discharge".into()) { self.next_token()?; @@ -807,8 +833,7 @@ self.next_token()?; let var = self.expect_symbol()?; let value = self.parse_term()?; - let sort = Term::Sort(self.pool.sort(&value).clone()); - let sort = self.pool.add(sort); + let sort = self.pool.sort(&value); self.insert_sorted_var((var.clone(), sort)); self.expect_token(Token::CloseParen)?; AnchorArg::Assign(var, value) @@ -923,10 +948,10 @@ impl<'a, R: BufRead> Parser<'a, R> { /// Parses a term.
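The hunks below track the renamed `Term` constructors (`integer` to `new_int`, `real` to `new_real`, `string` to `new_string`, and `var` to `new_var`). A few illustrative lines, mirroring the test changes further down; the `pool` variable is assumed from context:

let forty_two = pool.add(Term::new_int(42)); // was Term::integer(42)
let one_half = pool.add(Term::new_real((1, 2))); // was Term::real((1, 2))
let greeting = pool.add(Term::new_string("foo")); // was Term::string("foo")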
pub fn parse_term(&mut self) -> CarcaraResult> { let term = match self.next_token()? { - (Token::Numeral(n), _) if self.interpret_integers_as_reals => Term::real(n), - (Token::Numeral(n), _) => Term::integer(n), - (Token::Decimal(r), _) => Term::real(r), - (Token::String(s), _) => Term::string(s), + (Token::Numeral(n), _) if self.interpret_integers_as_reals => Term::new_real(n), + (Token::Numeral(n), _) => Term::new_int(n), + (Token::Decimal(r), _) => Term::new_real(r), + (Token::String(s), _) => Term::new_string(s), (Token::Symbol(s), pos) => { // Check to see if there is a nullary function defined with this name return Ok(if let Some(func_def) = self.state.function_defs.get(&s) { @@ -939,7 +964,7 @@ impl<'a, R: BufRead> Parser<'a, R> { )); } } else { - self.make_var(Identifier::Simple(s)) + self.make_var(Ident::Simple(s)) .map_err(|err| Error::Parser(err, pos))? }); } @@ -953,7 +978,7 @@ impl<'a, R: BufRead> Parser<'a, R> { fn parse_term_expecting_sort(&mut self, expected_sort: &Sort) -> CarcaraResult> { let pos = self.current_position; let term = self.parse_term()?; - SortError::assert_eq(expected_sort, self.pool.sort(&term)) + SortError::assert_eq(expected_sort, self.pool.sort(&term).as_sort().unwrap()) .map_err(|e| Error::Parser(e.into(), pos))?; Ok(term) } @@ -1020,7 +1045,7 @@ impl<'a, R: BufRead> Parser<'a, R> { p.expect_token(Token::OpenParen)?; let name = p.expect_symbol()?; let value = p.parse_term()?; - let sort = p.pool.add(Term::Sort(p.pool.sort(&value).clone())); + let sort = p.pool.sort(&value); p.insert_sorted_var((name.clone(), sort)); p.expect_token(Token::CloseParen)?; Ok((name, value)) @@ -1031,12 +1056,11 @@ impl<'a, R: BufRead> Parser<'a, R> { self.expect_token(Token::CloseParen)?; self.state.symbol_table.pop_scope(); - if self.expand_lets { + if self.config.expand_lets { let substitution = bindings .into_iter() .map(|(name, value)| { - let sort = Term::Sort(self.pool.sort(&value).clone()); - let var = Term::var(name, self.pool.add(sort)); + let var = Term::new_var(name, self.pool.sort(&value)); (self.pool.add(var), value) }) .collect(); @@ -1142,8 +1166,11 @@ impl<'a, R: BufRead> Parser<'a, R> { assert_num_args(&args, func.params.len()) .map_err(|err| Error::Parser(err, head_pos))?; for (arg, param) in args.iter().zip(func.params.iter()) { - SortError::assert_eq(param.1.as_sort().unwrap(), self.pool.sort(arg)) - .map_err(|err| Error::Parser(err.into(), head_pos))?; + SortError::assert_eq( + param.1.as_sort().unwrap(), + self.pool.sort(arg).as_sort().unwrap(), + ) + .map_err(|err| Error::Parser(err.into(), head_pos))?; } // Build a hash map of all the parameter names and the values they will @@ -1152,7 +1179,7 @@ impl<'a, R: BufRead> Parser<'a, R> { .params .iter() .zip(args) - .map(|((name, sort), arg)| (self.pool.add(Term::var(name, sort.clone())), arg)) + .map(|((n, s), arg)| (self.pool.add(Term::new_var(n, s.clone())), arg)) .collect(); // Since we already checked the sorts of the arguments, creating this substitution diff --git a/carcara/src/parser/tests.rs b/carcara/src/parser/tests.rs index 90394ef1..e3651c00 100644 --- a/carcara/src/parser/tests.rs +++ b/carcara/src/parser/tests.rs @@ -3,16 +3,23 @@ #![cfg(test)] use super::*; +use crate::ast::pool::PrimitivePool; const ERROR_MESSAGE: &str = "parser error during test"; +const TEST_CONFIG: Config = Config { + // Some tests need function definitions to be applied + apply_function_defs: true, + expand_lets: false, + allow_int_real_subtyping: false, +}; + pub fn parse_terms( - pool: &mut TermPool, + pool: &mut 
PrimitivePool, definitions: &str, terms: [&str; N], ) -> [Rc; N] { - let mut parser = - Parser::new(pool, definitions.as_bytes(), true, false, false).expect(ERROR_MESSAGE); + let mut parser = Parser::new(pool, TEST_CONFIG, definitions.as_bytes()).expect(ERROR_MESSAGE); parser.parse_problem().expect(ERROR_MESSAGE); terms.map(|s| { @@ -21,8 +28,8 @@ pub fn parse_terms( }) } -pub fn parse_term(pool: &mut TermPool, input: &str) -> Rc { - Parser::new(pool, input.as_bytes(), true, false, false) +pub fn parse_term(pool: &mut PrimitivePool, input: &str) -> Rc { + Parser::new(pool, TEST_CONFIG, input.as_bytes()) .and_then(|mut parser| parser.parse_term()) .expect(ERROR_MESSAGE) } @@ -30,22 +37,22 @@ pub fn parse_term(pool: &mut TermPool, input: &str) -> Rc { /// Tries to parse a term from a `&str`, expecting it to fail. Returns the error encountered, or /// panics if no error is encountered. pub fn parse_term_err(input: &str) -> Error { - let mut pool = TermPool::new(); - Parser::new(&mut pool, input.as_bytes(), true, false, false) + let mut pool = PrimitivePool::new(); + Parser::new(&mut pool, TEST_CONFIG, input.as_bytes()) .and_then(|mut p| p.parse_term()) .expect_err("expected error") } /// Parses a proof from a `&str`. Panics if any error is encountered. -pub fn parse_proof(pool: &mut TermPool, input: &str) -> Proof { - let commands = Parser::new(pool, input.as_bytes(), true, false, false) +pub fn parse_proof(pool: &mut PrimitivePool, input: &str) -> Proof { + let commands = Parser::new(pool, TEST_CONFIG, input.as_bytes()) .expect(ERROR_MESSAGE) .parse_proof() .expect(ERROR_MESSAGE); - Proof { premises: AHashSet::new(), commands } + Proof { premises: IndexSet::new(), commands } } -fn run_parser_tests(pool: &mut TermPool, cases: &[(&str, Rc)]) { +fn run_parser_tests(pool: &mut PrimitivePool, cases: &[(&str, Rc)]) { for (case, expected) in cases { let got = parse_term(pool, case); assert_eq!(expected, &got); @@ -54,9 +61,9 @@ fn run_parser_tests(pool: &mut TermPool, cases: &[(&str, Rc)]) { #[test] fn test_hash_consing() { - use ahash::AHashSet; + use indexmap::IndexSet; - let mut pool = TermPool::new(); + let mut pool = PrimitivePool::new(); let input = "(- (- (+ 1 2) @@ -64,7 +71,7 @@ fn test_hash_consing() { ) (* 2 2) )"; - let mut parser = Parser::new(&mut pool, input.as_bytes(), true, false, false).unwrap(); + let mut parser = Parser::new(&mut pool, Config::new(), input.as_bytes()).unwrap(); parser.parse_term().unwrap(); // We expect this input to result in 7 unique terms after parsing: @@ -82,6 +89,7 @@ fn test_hash_consing() { "true", "false", "1", + "Int", "2", "(+ 1 2)", "(* (+ 1 2) (+ 1 2))", @@ -90,11 +98,11 @@ fn test_hash_consing() { "(- (- (+ 1 2) (* (+ 1 2) (+ 1 2))) (* 2 2))", ] .into_iter() - .collect::>(); - - assert_eq!(pool.terms.len(), expected.len()); + .collect::>(); - for got in pool.terms.keys() { + let pool_terms = pool.storage.into_vec(); + assert_eq!(pool_terms.len(), expected.len()); + for got in pool_terms { let formatted: &str = &format!("{:#}", got); assert!(expected.contains(formatted), "{}", formatted); } @@ -102,16 +110,16 @@ fn test_hash_consing() { #[test] fn test_constant_terms() { - let mut p = TermPool::new(); - assert_eq!(Term::integer(42), *parse_term(&mut p, "42")); - assert_eq!(Term::real((3, 2)), *parse_term(&mut p, "1.5")); - assert_eq!(Term::string("foo"), *parse_term(&mut p, "\"foo\"")); + let mut p = PrimitivePool::new(); + assert_eq!(Term::new_int(42), *parse_term(&mut p, "42")); + assert_eq!(Term::new_real((3, 2)), *parse_term(&mut p, "1.5")); + 
assert_eq!(Term::new_string("foo"), *parse_term(&mut p, "\"foo\"")); } #[test] fn test_arithmetic_ops() { - let mut p = TermPool::new(); - let [one, two, three, five, seven] = [1, 2, 3, 5, 7].map(|n| p.add(Term::integer(n))); + let mut p = PrimitivePool::new(); + let [one, two, three, five, seven] = [1, 2, 3, 5, 7].map(|n| p.add(Term::new_int(n))); let cases = [ ( "(+ 2 3)", @@ -140,8 +148,8 @@ fn test_arithmetic_ops() { #[test] fn test_logic_ops() { - let mut p = TermPool::new(); - let [zero, one, two, three, four] = [0, 1, 2, 3, 4].map(|n| p.add(Term::integer(n))); + let mut p = PrimitivePool::new(); + let [zero, one, two, three, four] = [0, 1, 2, 3, 4].map(|n| p.add(Term::new_int(n))); let cases = [ ( "(and true false)", @@ -222,8 +230,8 @@ fn test_logic_ops() { #[test] fn test_ite() { - let mut p = TermPool::new(); - let [one, two, three] = [1, 2, 3].map(|n| p.add(Term::integer(n))); + let mut p = PrimitivePool::new(); + let [one, two, three] = [1, 2, 3].map(|n| p.add(Term::new_int(n))); let cases = [ ( "(ite true 2 3)", @@ -259,12 +267,12 @@ fn test_ite() { #[test] fn test_quantifiers() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let bool_sort = p.add(Term::Sort(Sort::Bool)); let real_sort = p.add(Term::Sort(Sort::Real)); let cases = [ ("(exists ((p Bool)) p)", { - let inner = p.add(Term::var("p", bool_sort.clone())); + let inner = p.add(Term::new_var("p", bool_sort.clone())); p.add(Term::Quant( Quantifier::Exists, BindingList(vec![("p".into(), bool_sort)]), @@ -272,9 +280,9 @@ fn test_quantifiers() { )) }), ("(forall ((x Real) (y Real)) (= (+ x y) 0.0))", { - let [x, y] = ["x", "y"].map(|s| p.add(Term::var(s, real_sort.clone()))); + let [x, y] = ["x", "y"].map(|s| p.add(Term::new_var(s, real_sort.clone()))); let x_plus_y = p.add(Term::Op(Operator::Add, vec![x, y])); - let zero = p.add(Term::real(0)); + let zero = p.add(Term::new_real(0)); let inner = p.add(Term::Op(Operator::Equals, vec![x_plus_y, zero])); p.add(Term::Quant( Quantifier::Forall, @@ -299,17 +307,17 @@ fn test_quantifiers() { #[test] fn test_choice_terms() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let bool_sort = p.add(Term::Sort(Sort::Bool)); let int_sort = p.add(Term::Sort(Sort::Int)); let cases = [ ("(choice ((p Bool)) p)", { - let inner = p.add(Term::var("p", bool_sort.clone())); + let inner = p.add(Term::new_var("p", bool_sort.clone())); p.add(Term::Choice(("p".into(), bool_sort), inner)) }), ("(choice ((x Int)) (= x 0))", { - let x = p.add(Term::var("x", int_sort.clone())); - let zero = p.add(Term::integer(0)); + let x = p.add(Term::new_var("x", int_sort.clone())); + let zero = p.add(Term::new_int(0)); let inner = p.add(Term::Op(Operator::Equals, vec![x, zero])); p.add(Term::Choice(("x".into(), int_sort), inner)) }), @@ -327,20 +335,20 @@ fn test_choice_terms() { #[test] fn test_let_terms() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let int_sort = p.add(Term::Sort(Sort::Int)); let bool_sort = p.add(Term::Sort(Sort::Bool)); let cases = [ ("(let ((p false)) p)", { - let inner = p.add(Term::var("p", bool_sort)); + let inner = p.add(Term::new_var("p", bool_sort)); p.add(Term::Let( BindingList(vec![("p".into(), p.bool_false())]), inner, )) }), ("(let ((x 1) (y 2)) (+ x y))", { - let [one, two] = [1, 2].map(|n| p.add(Term::integer(n))); - let [x, y] = ["x", "y"].map(|s| p.add(Term::var(s, int_sort.clone()))); + let [one, two] = [1, 2].map(|n| p.add(Term::new_int(n))); + let [x, y] = ["x", "y"].map(|s| p.add(Term::new_var(s, 
int_sort.clone()))); let inner = p.add(Term::Op(Operator::Add, vec![x, y])); p.add(Term::Let( BindingList(vec![("x".into(), one), ("y".into(), two)]), @@ -357,18 +365,18 @@ fn test_let_terms() { #[test] fn test_lambda_terms() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let int_sort = p.add(Term::Sort(Sort::Int)); let cases = [ ("(lambda ((x Int)) x)", { - let x = p.add(Term::var("x", int_sort.clone())); + let x = p.add(Term::new_var("x", int_sort.clone())); p.add(Term::Lambda( BindingList(vec![("x".into(), int_sort.clone())]), x, )) }), ("(lambda ((x Int) (y Int)) (+ x y))", { - let [x, y] = ["x", "y"].map(|s| p.add(Term::var(s, int_sort.clone()))); + let [x, y] = ["x", "y"].map(|s| p.add(Term::new_var(s, int_sort.clone()))); let inner = p.add(Term::Op(Operator::Add, vec![x, y])); p.add(Term::Lambda( BindingList(vec![("x".into(), int_sort.clone()), ("y".into(), int_sort)]), @@ -389,8 +397,8 @@ fn test_lambda_terms() { #[test] fn test_annotated_terms() { - let mut p = TermPool::new(); - let [zero, two, three] = [0, 2, 3].map(|n| p.add(Term::integer(n))); + let mut p = PrimitivePool::new(); + let [zero, two, three] = [0, 2, 3].map(|n| p.add(Term::new_int(n))); let cases = [ ("(! 0 :named foo)", zero.clone()), ("(! (! 0 :named foo) :named bar)", zero.clone()), @@ -420,7 +428,7 @@ fn test_annotated_terms() { #[test] fn test_declare_fun() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); parse_terms( &mut p, @@ -437,12 +445,12 @@ fn test_declare_fun() { let [got] = parse_terms(&mut p, "(declare-fun x () Real)", ["x"]); let real_sort = p.add(Term::Sort(Sort::Real)); - assert_eq!(p.add(Term::var("x", real_sort)), got); + assert_eq!(p.add(Term::new_var("x", real_sort)), got); } #[test] fn test_declare_sort() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); parse_terms( &mut p, @@ -462,12 +470,12 @@ fn test_declare_sort() { ["x"], ); let expected_sort = p.add(Term::Sort(Sort::Atom("T".to_owned(), Vec::new()))); - assert_eq!(p.add(Term::var("x", expected_sort)), got); + assert_eq!(p.add(Term::new_var("x", expected_sort)), got); } #[test] fn test_define_fun() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let [got] = parse_terms( &mut p, "(define-fun add ((a Int) (b Int)) Int (+ a b))", @@ -488,9 +496,36 @@ fn test_define_fun() { assert_eq!(expected, got); } +#[test] +fn test_assume() { + let mut p = PrimitivePool::new(); + let input = " + (assume h1 true) + (assume h2 (or true false) :ignore \"extra\" :attributes) + "; + let proof = parse_proof(&mut p, input); + assert_eq!(proof.commands.len(), 2); + + assert_eq!( + &proof.commands[0], + &ProofCommand::Assume { + id: "h1".into(), + term: p.bool_true(), + } + ); + + assert_eq!( + &proof.commands[1], + &ProofCommand::Assume { + id: "h2".into(), + term: parse_term(&mut p, "(or true false)"), + } + ); +} + #[test] fn test_step() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let input = " (step t1 (cl (= (+ 2 3) (- 1 2))) :rule rule-name) (step t2 (cl) :rule rule-name :premises (t1)) @@ -534,10 +569,14 @@ fn test_step() { rule: "rule-name".into(), premises: Vec::new(), args: { - vec![Term::integer(1), Term::real(2), Term::string("three")] - .into_iter() - .map(|term| ProofArg::Term(p.add(term))) - .collect() + vec![ + Term::new_int(1), + Term::new_real(2), + Term::new_string("three"), + ] + .into_iter() + .map(|term| ProofArg::Term(p.add(term))) + .collect() }, discharge: Vec::new(), }) @@ -552,8 +591,8 @@ fn test_step() { premises: Vec::new(), 
args: { vec![ - ProofArg::Assign("a".into(), p.add(Term::integer(12))), - ProofArg::Assign("b".into(), p.add(Term::real((314, 100)))), + ProofArg::Assign("a".into(), p.add(Term::new_int(12))), + ProofArg::Assign("b".into(), p.add(Term::new_real((314, 100)))), ProofArg::Assign("c".into(), parse_term(&mut p, "(* 6 7)")), ] }, @@ -568,7 +607,7 @@ fn test_step() { clause: Vec::new(), rule: "rule-name".into(), premises: vec![(0, 0), (0, 1), (0, 2)], - args: vec![ProofArg::Term(p.add(Term::integer(42)))], + args: vec![ProofArg::Term(p.add(Term::new_int(42)))], discharge: Vec::new(), }) ); @@ -576,7 +615,7 @@ fn test_step() { #[test] fn test_premises_in_subproofs() { - let mut p = TermPool::new(); + let mut p = PrimitivePool::new(); let input = " (assume h1 true) (assume h2 true) diff --git a/carcara/src/utils.rs b/carcara/src/utils.rs index 10f0a2ae..dc662a7e 100644 --- a/carcara/src/utils.rs +++ b/carcara/src/utils.rs @@ -1,5 +1,5 @@ use crate::ast::{BindingList, Quantifier, Rc, Term}; -use ahash::{AHashMap, AHashSet, AHasher}; +use indexmap::{IndexMap, IndexSet}; use std::{ borrow::Borrow, fmt, @@ -25,7 +25,7 @@ pub fn is_symbol_character(ch: char) -> bool { /// An iterator that removes duplicate elements from `iter`. This will yield the elements in /// `iter` in order, skipping elements that have already been seen before. pub struct Dedup { - seen: AHashSet, + seen: IndexSet, iter: I, } @@ -59,7 +59,7 @@ impl> DedupIterator for I { where Self: Sized, { - Dedup { seen: AHashSet::new(), iter: self } + Dedup { seen: IndexSet::new(), iter: self } } } @@ -84,7 +84,7 @@ impl Hash for HashCache { impl HashCache { pub fn new(value: T) -> Self { - let mut hasher = AHasher::default(); + let mut hasher = std::collections::hash_map::DefaultHasher::default(); value.hash(&mut hasher); Self { hash: hasher.finish(), value } } @@ -101,26 +101,23 @@ impl AsRef for HashCache { } #[derive(Debug)] -pub struct SymbolTable { - scopes: Vec>, +pub struct HashMapStack { + scopes: Vec>, } -impl SymbolTable { +impl HashMapStack { pub fn new() -> Self { - Self { scopes: vec![AHashMap::new()] } + Self { scopes: vec![IndexMap::new()] } } pub fn push_scope(&mut self) { - self.scopes.push(AHashMap::new()); + self.scopes.push(IndexMap::new()); } pub fn pop_scope(&mut self) { match self.scopes.len() { 0 => unreachable!(), - 1 => { - log::error!("cannot pop last scope in symbol table"); - panic!(); - } + 1 => panic!("trying to pop last scope in `HashMapStack`"), _ => { self.scopes.pop().unwrap(); } @@ -128,7 +125,7 @@ impl SymbolTable { } } -impl SymbolTable { +impl HashMapStack { pub fn get(&self, key: &Q) -> Option<&V> where K: Borrow, @@ -161,7 +158,7 @@ impl SymbolTable { } } -impl Default for SymbolTable { +impl Default for HashMapStack { fn default() -> Self { Self::new() } diff --git a/carcara/tests/test_example_files.rs b/carcara/tests/test_example_files.rs index fff74c62..17385630 100644 --- a/carcara/tests/test_example_files.rs +++ b/carcara/tests/test_example_files.rs @@ -4,37 +4,66 @@ use std::{ path::{Path, PathBuf}, }; +fn run_parallel_checker_test( + problem_path: &Path, + proof_path: &Path, + num_threads: usize, +) -> CarcaraResult<()> { + use checker::Config; + use std::sync::Arc; + + let (prelude, proof, pool) = parser::parse_instance( + io::BufReader::new(fs::File::open(problem_path)?), + io::BufReader::new(fs::File::open(proof_path)?), + parser::Config::new(), + )?; + + let (scheduler, schedule_context_usage) = checker::Scheduler::new(num_threads, &proof); + let mut checker = 
checker::ParallelProofChecker::new( + Arc::new(pool), + Config::new(), + &prelude, + &schedule_context_usage, + 128 * 1024 * 1024, + ); + checker.check(&proof, &scheduler)?; + Ok(()) +} + fn run_test(problem_path: &Path, proof_path: &Path) -> CarcaraResult<()> { use checker::Config; let (prelude, proof, mut pool) = parser::parse_instance( io::BufReader::new(fs::File::open(problem_path)?), io::BufReader::new(fs::File::open(proof_path)?), - true, - false, - false, + parser::Config::new(), )?; // First, we check the proof normally - checker::ProofChecker::new(&mut pool, Config::new(), prelude.clone()).check(&proof)?; + checker::ProofChecker::new(&mut pool, Config::new(), &prelude).check(&proof)?; // Then, we check it while elaborating the proof - let mut checker = checker::ProofChecker::new(&mut pool, Config::new(), prelude.clone()); + let mut checker = checker::ProofChecker::new(&mut pool, Config::new(), &prelude); let (_, elaborated) = checker.check_and_elaborate(proof)?; // After that, we check the elaborated proof normally, to make sure it is valid - checker::ProofChecker::new(&mut pool, Config::new().strict(true), prelude.clone()) + checker::ProofChecker::new(&mut pool, Config::new().strict(true), &prelude) .check(&elaborated)?; // Finally, we elaborate the already elaborated proof, to make sure the elaboration step is // idempotent - let mut checker = checker::ProofChecker::new(&mut pool, Config::new().strict(true), prelude); + let mut checker = checker::ProofChecker::new(&mut pool, Config::new().strict(true), &prelude); let (_, elaborated_twice) = checker.check_and_elaborate(elaborated.clone())?; assert!( elaborated.commands == elaborated_twice.commands, "elaboration was not idempotent!" ); + // We also test the parallel checker, with different values for the number of threads + run_parallel_checker_test(problem_path, proof_path, 1)?; + run_parallel_checker_test(problem_path, proof_path, 4)?; + run_parallel_checker_test(problem_path, proof_path, 16)?; + Ok(()) } diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 33bfb204..27f22106 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,8 +1,8 @@ [package] name = "carcara-cli" -version = "1.0.0" +version = "1.1.0" edition = "2021" -rust-version = "1.67" +rust-version = "1.72" license = "Apache-2.0" [[bin]] @@ -11,11 +11,9 @@ path = "src/main.rs" [dependencies] carcara = { path = "../carcara" } -ahash = "0.8.3" -clap = { version = "3.2.23", features = ["derive"] } -const_format = "0.2.30" +clap = { version = "3.2.25", features = ["derive"] } +const_format = "0.2.31" crossbeam-queue = "0.3.8" -log = { version = "0.4.17", features = ["std"] } +log = { version = "0.4.20", features = ["std"] } ansi_term = "0.12" git-version = "0.3.5" -atty = "0.2.14" diff --git a/cli/src/benchmarking.rs b/cli/src/benchmarking.rs index a1efa273..82dbdf14 100644 --- a/cli/src/benchmarking.rs +++ b/cli/src/benchmarking.rs @@ -1,8 +1,6 @@ use carcara::{ benchmarking::{CollectResults, CsvBenchmarkResults, RunMeasurement}, - checker, - parser::parse_instance, - CarcaraOptions, + checker, parser, CarcaraOptions, }; use crossbeam_queue::ArrayQueue; use std::{ @@ -20,74 +18,74 @@ struct JobDescriptor<'a> { run_index: usize, } -fn run_job( +fn run_job( results: &mut T, job: JobDescriptor, options: &CarcaraOptions, elaborate: bool, ) -> Result { let proof_file_name = job.proof_file.to_str().unwrap(); + let mut checker_stats = checker::CheckerStatistics { + file_name: proof_file_name, + elaboration_time: Duration::ZERO, + polyeq_time: Duration::ZERO, + assume_time: 
Duration::ZERO, + assume_core_time: Duration::ZERO, + results: std::mem::take(results), + }; let total = Instant::now(); let parsing = Instant::now(); - let (prelude, proof, mut pool) = parse_instance( + let config = parser::Config { + apply_function_defs: options.apply_function_defs, + expand_lets: options.expand_lets, + allow_int_real_subtyping: options.allow_int_real_subtyping, + }; + let (prelude, proof, mut pool) = parser::parse_instance( BufReader::new(File::open(job.problem_file)?), BufReader::new(File::open(job.proof_file)?), - options.apply_function_defs, - options.expand_lets, - options.allow_int_real_subtyping, + config, )?; let parsing = parsing.elapsed(); - let mut elaboration = Duration::ZERO; - let mut deep_eq = Duration::ZERO; - let mut assume = Duration::ZERO; - let mut assume_core = Duration::ZERO; - let config = checker::Config::new() .strict(options.strict) - .skip_unknown_rules(options.skip_unknown_rules) - .lia_via_cvc5(options.lia_via_cvc5) - .statistics(checker::CheckerStatistics { - file_name: proof_file_name, - elaboration_time: &mut elaboration, - deep_eq_time: &mut deep_eq, - assume_time: &mut assume, - assume_core_time: &mut assume_core, - results, - }); - let mut checker = checker::ProofChecker::new(&mut pool, config, prelude); + .ignore_unknown_rules(options.ignore_unknown_rules) + .lia_options(options.lia_options.clone()); + let mut checker = checker::ProofChecker::new(&mut pool, config, &prelude); let checking = Instant::now(); let checking_result = if elaborate { checker - .check_and_elaborate(proof) + .check_and_elaborate_with_stats(proof, &mut checker_stats) .map(|(is_holey, _)| is_holey) } else { - checker.check(&proof) + checker.check_with_stats(&proof, &mut checker_stats) }; let checking = checking.elapsed(); let total = total.elapsed(); - results.add_run_measurement( + checker_stats.results.add_run_measurement( &(proof_file_name.to_string(), job.run_index), RunMeasurement { parsing, checking, - elaboration, + elaboration: checker_stats.elaboration_time, + scheduling: Duration::ZERO, total, - deep_eq, - assume, - assume_core, + polyeq: checker_stats.polyeq_time, + assume: checker_stats.assume_time, + assume_core: checker_stats.assume_core_time, }, ); + *results = checker_stats.results; checking_result } -fn worker_thread( +fn worker_thread( jobs_queue: &ArrayQueue, options: &CarcaraOptions, elaborate: bool, @@ -111,7 +109,7 @@ fn worker_thread( pub fn run_benchmark( instances: &[(PathBuf, PathBuf)], num_runs: usize, - num_threads: usize, + num_jobs: usize, options: &CarcaraOptions, elaborate: bool, ) -> T { @@ -135,7 +133,7 @@ pub fn run_benchmark( // We of course need to `collect` here to ensure we spawn all threads before starting to // `join` them #[allow(clippy::needless_collect)] - let workers: Vec<_> = (0..num_threads) + let workers: Vec<_> = (0..num_jobs) .map(|_| { thread::Builder::new() .stack_size(STACK_SIZE) @@ -155,14 +153,14 @@ pub fn run_benchmark( pub fn run_csv_benchmark( instances: &[(PathBuf, PathBuf)], num_runs: usize, - num_threads: usize, + num_jobs: usize, options: &CarcaraOptions, elaborate: bool, runs_dest: &mut dyn io::Write, by_rule_dest: &mut dyn io::Write, ) -> io::Result<()> { let result: CsvBenchmarkResults = - run_benchmark(instances, num_runs, num_threads, options, elaborate); + run_benchmark(instances, num_runs, num_jobs, options, elaborate); println!( "{} errors encountered during benchmark", result.num_errors() diff --git a/cli/src/error.rs b/cli/src/error.rs index 53073039..ab70fec4 100644 --- a/cli/src/error.rs +++ 
b/cli/src/error.rs @@ -4,6 +4,7 @@ use std::{fmt, io, path::PathBuf}; pub enum CliError { CarcaraError(carcara::Error), CantInferProblemFile(PathBuf), + InvalidSliceId(String), BothFilesStdin, } @@ -29,6 +30,7 @@ impl fmt::Display for CliError { write!(f, "can't infer problem file: {}", p.display()) } CliError::BothFilesStdin => write!(f, "problem and proof files can't both be `-`"), + CliError::InvalidSliceId(id) => write!(f, "invalid id for slice: {}", id), } } } diff --git a/cli/src/main.rs b/cli/src/main.rs index 3549669f..2710b6fb 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -4,9 +4,8 @@ mod logger; mod path_args; use carcara::{ - ast::print_proof, - benchmarking::{Metrics, OnlineBenchmarkResults}, - check, check_and_elaborate, parser, CarcaraOptions, + ast::print_proof, benchmarking::OnlineBenchmarkResults, check, check_and_elaborate, + check_parallel, parser, CarcaraOptions, LiaGenericOptions, }; use clap::{AppSettings, ArgEnum, Args, Parser, Subcommand}; use const_format::{formatcp, str_index}; @@ -15,7 +14,7 @@ use git_version::git_version; use path_args::{get_instances_from_paths, infer_problem_path}; use std::{ fs::File, - io::{self, BufRead}, + io::{self, BufRead, IsTerminal}, path::Path, }; @@ -70,6 +69,9 @@ enum Command { /// Checks a series of proof files and records performance statistics. Bench(BenchCommandOptions), + + /// Given a step, takes a slice of a proof consisting of all its transitive premises. + Slice(SliceCommandOption), } #[derive(Args)] @@ -82,6 +84,20 @@ struct Input { problem_file: Option, } +#[derive(Args)] +struct StatsOptions { + /// Enables the gathering of performance statistics + #[clap(long)] + stats: bool, +} + +#[derive(Args)] +struct StackOptions { + /// Defines the thread stack size for each check worker (does not include the main thread stack size, which should be set manually). + #[clap(long, default_value = "0")] + stack_size: usize, +} + #[derive(Args, Clone, Copy)] struct ParsingOptions { /// Expand function definitions introduced by `define-fun`s in the SMT problem. If this flag is @@ -101,18 +117,36 @@ struct ParsingOptions { allow_int_real_subtyping: bool, } -#[derive(Args, Clone, Copy)] +#[derive(Args, Clone)] struct CheckingOptions { /// Enables the strict checking of certain rules. #[clap(short, long)] strict: bool, - /// Skips rules that are not known by the checker. - #[clap(long)] + /// Allow steps with rules that are not known by the checker, and consider them as holes. + #[clap(short, long)] + ignore_unknown_rules: bool, + + // Note: the `--skip-unknown-rules` flag has been deprecated in favor of `--ignore-unknown-rules` + #[clap(long, conflicts_with("ignore-unknown-rules"), hide = true)] skip_unknown_rules: bool, - /// Check `lia_generic` steps by calling into cvc5. + /// Check `lia_generic` steps using the provided solver. #[clap(long)] + lia_solver: Option, + + /// The arguments to pass to the `lia_generic` solver. This should be a single string where + /// multiple arguments are separated by spaces. + #[clap( + long, + requires = "lia-solver", + allow_hyphen_values = true, + default_value = "--tlimit=10000 --lang=smt2 --proof-format-mode=alethe --proof-granularity=theory-rewrite --proof-alethe-res-pivots" + )] + lia_solver_args: String, + + /// Check `lia_generic` steps by calling into cvc5 (deprecated). 
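The deprecated flag is folded into the new `--lia-solver` option when the checker options are built. A sketch of that mapping, mirroring `build_carcara_options` below; the standalone function is illustrative, not part of the patch:

fn lia_options_from_flags(
    lia_solver: Option<String>,
    lia_via_cvc5: bool,
    lia_solver_args: &str,
) -> Option<LiaGenericOptions> {
    // `--lia-via-cvc5` survives only as a deprecated alias for `--lia-solver cvc5`.
    let solver = lia_solver.or_else(|| lia_via_cvc5.then(|| "cvc5".into()))?;
    Some(LiaGenericOptions {
        solver: solver.into(),
        // A single space-separated string on the CLI becomes one argument per word.
        arguments: lia_solver_args.split_whitespace().map(Into::into).collect(),
    })
}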
+ #[clap(long, conflicts_with("lia-solver"))] lia_via_cvc5: bool, } @@ -131,17 +165,29 @@ fn build_carcara_options( }: ParsingOptions, CheckingOptions { strict, + ignore_unknown_rules, skip_unknown_rules, + lia_solver, lia_via_cvc5, + lia_solver_args, }: CheckingOptions, + StatsOptions { stats }: StatsOptions, ) -> CarcaraOptions { + // If no solver is provided by the `--lia-solver` option, *and* the `--lia-via-cvc5` option was + // passed, we default to cvc5 as a solver + let solver = lia_solver.or_else(|| lia_via_cvc5.then(|| "cvc5".into())); + let lia_options = solver.map(|solver| LiaGenericOptions { + solver: solver.into(), + arguments: lia_solver_args.split_whitespace().map(Into::into).collect(), + }); CarcaraOptions { apply_function_defs, expand_lets: expand_let_bindings, allow_int_real_subtyping, - lia_via_cvc5, + lia_options, strict, - skip_unknown_rules, + ignore_unknown_rules: ignore_unknown_rules || skip_unknown_rules, + stats, } } @@ -167,6 +213,26 @@ struct CheckCommandOptions { #[clap(flatten)] checking: CheckingOptions, + + /// Defines the number of cores for proof checking. + #[clap(short = 'u', long, required = false, default_value = "1", validator = |s: &str| -> Result<(), String> { + if let Ok(n) = s.to_string().parse() as Result { + if n < 1 { + Err(format!("The threads number can't be {n}.")) + } else { + Ok(()) + } + } else { + Err(String::from("Not a number.")) + } + })] + num_threads: usize, + + #[clap(flatten)] + stats: StatsOptions, + + #[clap(flatten)] + stack: StackOptions, } #[derive(Args)] @@ -182,6 +248,9 @@ struct ElaborateCommandOptions { #[clap(flatten)] printing: PrintingOptions, + + #[clap(flatten)] + stats: StatsOptions, } #[derive(Args)] @@ -200,9 +269,9 @@ struct BenchCommandOptions { #[clap(short, long, default_value_t = 1)] num_runs: usize, - /// Number of threads to use when running the benchmark. + /// Number of jobs to run simultaneously when running the benchmark. #[clap(short = 'j', long, default_value_t = 1)] - num_threads: usize, + num_jobs: usize, /// Show benchmark results sorted by total time taken, instead of by average time taken. #[clap(short = 't', long)] @@ -213,11 +282,29 @@ struct BenchCommandOptions { dump_to_csv: bool, /// The proof files on which the benchmark will be run. If a directory is passed, the checker - /// will recursively find all '.proof' files in the directory. The problem files will be + /// will recursively find all proof files in the directory. The problem files will be /// inferred from the proof files. files: Vec, } +#[derive(Args)] +struct SliceCommandOption { + #[clap(flatten)] + input: Input, + + #[clap(flatten)] + parsing: ParsingOptions, + + #[clap(flatten)] + printing: PrintingOptions, + + #[clap(long)] + from: String, + + #[clap(long, short = 'd')] + max_distance: Option, +} + #[derive(ArgEnum, Clone)] enum LogLevel { Off, @@ -239,9 +326,26 @@ impl From for log::LevelFilter { fn main() { let cli = Cli::parse(); - let colors_enabled = !cli.no_color && atty::is(atty::Stream::Stderr); + let colors_enabled = !cli.no_color && std::io::stderr().is_terminal(); logger::init(cli.log_level.into(), colors_enabled); + if let Command::Check(CheckCommandOptions { checking, .. }) + | Command::Elaborate(ElaborateCommandOptions { checking, .. }) + | Command::Bench(BenchCommandOptions { checking, .. 
}) = &cli.command + { + if checking.skip_unknown_rules { + log::warn!( + "the `--skip-unknown-rules` option is deprecated, please use \ + `--ignore-unknown-rules` instead" + ) + } + if checking.lia_via_cvc5 { + log::warn!( + "the `--lia-via-cvc5` option is deprecated, please use `--lia-solver cvc5` instead" + ) + } + } + let result = match cli.command { Command::Parse(options) => parse_command(options), Command::Check(options) => { @@ -258,6 +362,7 @@ fn main() { } Command::Elaborate(options) => elaborate_command(options), Command::Bench(options) => bench_command(options), + Command::Slice(options) => slice_command(options), }; if let Err(e) = result { log::error!("{}", e); @@ -287,9 +392,11 @@ fn parse_command(options: ParseCommandOptions) -> CliResult<()> { let (_, proof, _) = parser::parse_instance( problem, proof, - options.parsing.apply_function_defs, - options.parsing.expand_let_bindings, - options.parsing.allow_int_real_subtyping, + parser::Config { + apply_function_defs: options.parsing.apply_function_defs, + expand_lets: options.parsing.expand_let_bindings, + allow_int_real_subtyping: options.parsing.allow_int_real_subtyping, + }, ) .map_err(carcara::Error::from)?; print_proof(&proof.commands, options.printing.use_sharing)?; @@ -298,11 +405,18 @@ fn parse_command(options: ParseCommandOptions) -> CliResult<()> { fn check_command(options: CheckCommandOptions) -> CliResult { let (problem, proof) = get_instance(&options.input)?; - check( - problem, - proof, - build_carcara_options(options.parsing, options.checking), - ) + let carc_options = build_carcara_options(options.parsing, options.checking, options.stats); + if options.num_threads == 1 { + check(problem, proof, carc_options) + } else { + check_parallel( + problem, + proof, + carc_options, + options.num_threads, + options.stack.stack_size, + ) + } .map_err(Into::into) } @@ -312,7 +426,7 @@ fn elaborate_command(options: ElaborateCommandOptions) -> CliResult<()> { let (_, elaborated) = check_and_elaborate( problem, proof, - build_carcara_options(options.parsing, options.checking), + build_carcara_options(options.parsing, options.checking, options.stats), )?; print_proof(&elaborated.commands, options.printing.use_sharing)?; Ok(()) @@ -331,12 +445,17 @@ fn bench_command(options: BenchCommandOptions) -> CliResult<()> { options.num_runs ); + let carc_options = build_carcara_options( + options.parsing, + options.checking, + StatsOptions { stats: false }, + ); if options.dump_to_csv { benchmarking::run_csv_benchmark( &instances, options.num_runs, - options.num_threads, - &build_carcara_options(options.parsing, options.checking), + options.num_jobs, + &carc_options, options.elaborate, &mut File::create("runs.csv")?, &mut File::create("by-rule.csv")?, @@ -347,8 +466,8 @@ fn bench_command(options: BenchCommandOptions) -> CliResult<()> { let results: OnlineBenchmarkResults = benchmarking::run_benchmark( &instances, options.num_runs, - options.num_threads, - &build_carcara_options(options.parsing, options.checking), + options.num_jobs, + &carc_options, options.elaborate, ); if results.is_empty() { @@ -363,119 +482,29 @@ fn bench_command(options: BenchCommandOptions) -> CliResult<()> { } else { println!("valid"); } - print_benchmark_results(results, options.sort_by_total) + results.print(options.sort_by_total); + Ok(()) } -fn print_benchmark_results(results: OnlineBenchmarkResults, sort_by_total: bool) -> CliResult<()> { - let [parsing, checking, elaborating, accounted_for, total] = [ - results.parsing(), - results.checking(), - 
results.elaborating(), - results.total_accounted_for(), - results.total(), - ] - .map(|m| { - if sort_by_total { - format!("{:#}", m) - } else { - format!("{}", m) - } - }); - - println!("parsing: {}", parsing); - println!("checking: {}", checking); - if !elaborating.is_empty() { - println!("elaborating: {}", elaborating); - } - println!( - "on assume: {} ({:.02}% of checking time)", - results.assume_time, - 100.0 * results.assume_time.mean().as_secs_f64() / results.checking().mean().as_secs_f64(), - ); - println!("on assume (core): {}", results.assume_core_time); - println!("assume ratio: {}", results.assume_time_ratio); - println!( - "on deep equality: {} ({:.02}% of checking time)", - results.deep_eq_time, - 100.0 * results.deep_eq_time.mean().as_secs_f64() / results.checking().mean().as_secs_f64(), - ); - println!("deep equality ratio: {}", results.deep_eq_time_ratio); - println!("total accounted for: {}", accounted_for); - println!("total: {}", total); - - let data_by_rule = results.step_time_by_rule(); - let mut data_by_rule: Vec<_> = data_by_rule.iter().collect(); - data_by_rule.sort_by_key(|(_, m)| if sort_by_total { m.total() } else { m.mean() }); - - println!("by rule:"); - for (rule, data) in data_by_rule { - print!(" {: <18}", rule); - if sort_by_total { - println!("{:#}", data) - } else { - println!("{}", data) - } - } - - println!("worst cases:"); - let worst_step = results.step_time().max(); - println!(" step: {} ({:?})", worst_step.0, worst_step.1); - - let worst_file_parsing = results.parsing().max(); - println!( - " file (parsing): {} ({:?})", - worst_file_parsing.0 .0, worst_file_parsing.1 - ); - - let worst_file_checking = results.checking().max(); - println!( - " file (checking): {} ({:?})", - worst_file_checking.0 .0, worst_file_checking.1 - ); - - let worst_file_assume = results.assume_time_ratio.max(); - println!( - " file (assume): {} ({:.04}%)", - worst_file_assume.0 .0, - worst_file_assume.1 * 100.0 - ); - - let worst_file_deep_eq = results.deep_eq_time_ratio.max(); - println!( - " file (deep_eq): {} ({:.04}%)", - worst_file_deep_eq.0 .0, - worst_file_deep_eq.1 * 100.0 - ); - - let worst_file_total = results.total().max(); - println!( - " file overall: {} ({:?})", - worst_file_total.0 .0, worst_file_total.1 - ); - - let num_hard_assumes = results.num_assumes - results.num_easy_assumes; - let percent_easy = (results.num_easy_assumes as f64) * 100.0 / (results.num_assumes as f64); - let percent_hard = (num_hard_assumes as f64) * 100.0 / (results.num_assumes as f64); - println!(" number of assumes: {}", results.num_assumes); - println!( - " (easy): {} ({:.02}%)", - results.num_easy_assumes, percent_easy - ); - println!( - " (hard): {} ({:.02}%)", - num_hard_assumes, percent_hard - ); - - let depths = results.deep_eq_depths; - if !depths.is_empty() { - println!(" max deep equality depth: {}", depths.max().1); - println!(" total deep equality depth: {}", depths.total()); - println!(" number of deep equalities: {}", depths.count()); - println!(" mean depth: {:.4}", depths.mean()); - println!( - "standard deviation of depth: {:.4}", - depths.standard_deviation() - ); - } +fn slice_command(options: SliceCommandOption) -> CliResult<()> { + let (problem, proof) = get_instance(&options.input)?; + let config = parser::Config { + apply_function_defs: options.parsing.apply_function_defs, + expand_lets: options.parsing.expand_let_bindings, + allow_int_real_subtyping: options.parsing.allow_int_real_subtyping, + }; + let (_, proof, _) = + parser::parse_instance(problem, proof, 
config).map_err(carcara::Error::from)?; + + let source_index = proof + .commands + .iter() + .position(|c| c.id() == options.from) + .ok_or_else(|| CliError::InvalidSliceId(options.from.to_owned()))?; + + let diff = + carcara::elaborator::slice_proof(&proof.commands, source_index, options.max_distance); + let slice = carcara::elaborator::apply_diff(diff, proof.commands); + print_proof(&slice, options.printing.use_sharing)?; Ok(()) } diff --git a/cli/src/path_args.rs b/cli/src/path_args.rs index 3e6208b7..b9fa2f24 100644 --- a/cli/src/path_args.rs +++ b/cli/src/path_args.rs @@ -4,6 +4,7 @@ use crate::error::CliError; use std::{ffi::OsStr, fs, path::PathBuf}; const SMT_FILE_EXTENSIONS: [&str; 3] = ["smt", "smt2", "smt_in"]; +const ALETHE_FILE_EXTENSIONS: [&str; 2] = ["alethe", "proof"]; pub fn infer_problem_path(proof_path: impl Into) -> Result { fn inner(mut path: PathBuf) -> Option { @@ -22,7 +23,11 @@ fn get_instances_from_dir( ) -> Result<(), CliError> { let file_type = fs::metadata(&path)?.file_type(); if file_type.is_file() { - if path.extension() == Some(OsStr::new("proof")) { + let is_proof_file = path + .extension() + .and_then(OsStr::to_str) + .is_some_and(|ext| ALETHE_FILE_EXTENSIONS.contains(&ext)); + if is_proof_file { let problem_file = infer_problem_path(&path)?; acc.push((problem_file, path)) } diff --git a/scripts/generate-benchmarks.sh b/scripts/generate-benchmarks.sh index 2d64707f..5feed2a2 100755 --- a/scripts/generate-benchmarks.sh +++ b/scripts/generate-benchmarks.sh @@ -88,7 +88,7 @@ find $benchmark_dir -name '*.smt2' | xargs -P $num_jobs -n 1 bash -c 'scripts/so if [ -n "clean_flag" ]; then echo "cleaning up..." for f in $(find $benchmark_dir -name '*.smt2'); do - if [ ! -f $f.proof ]; then + if [ ! -f $f.alethe ]; then rm -f $f fi done diff --git a/scripts/solve.sh b/scripts/solve.sh index 74ca30e2..d71a0f12 100755 --- a/scripts/solve.sh +++ b/scripts/solve.sh @@ -1,13 +1,13 @@ #!/bin/bash timeout $timeout $VERIT $1 \ - --proof-file-from-input --proof-with-sharing \ + --proof=$1.alethe --proof-with-sharing \ --proof-prune --proof-merge &> /dev/null # If a complete proof could not be generated, we delete it -if [ -f $1.proof ]; then - if ! grep -q -F '(cl)' $1.proof; then - rm $1.proof +if [ -f $1.alethe ]; then + if ! 
grep -q -F '(cl)' $1.alethe; then + rm $1.alethe exit fi fi diff --git a/test-generator/Cargo.toml b/test-generator/Cargo.toml index 99c009d7..af88c421 100644 --- a/test-generator/Cargo.toml +++ b/test-generator/Cargo.toml @@ -2,7 +2,7 @@ name = "test-generator" version = "0.1.0" edition = "2021" -rust-version = "1.67" +rust-version = "1.72" license = "Apache-2.0" [lib] diff --git a/test-generator/src/lib.rs b/test-generator/src/lib.rs index 4bcf10e9..5cb22584 100644 --- a/test-generator/src/lib.rs +++ b/test-generator/src/lib.rs @@ -54,11 +54,11 @@ pub fn from_dir(args: TokenStream, input: TokenStream) -> TokenStream { for entry in walkdir::WalkDir::new(&arg) { let Ok(entry) = entry else { continue }; - if entry.file_type().is_file() && entry.path().extension() == Some(OsStr::new("proof")) { + if entry.file_type().is_file() && entry.path().extension() == Some(OsStr::new("alethe")) { let path = entry.path().to_str().unwrap(); let new_ident = { let path = path.strip_prefix(&arg).unwrap().strip_prefix('/').unwrap(); - let path = path.strip_suffix(".proof").unwrap(); + let path = path.strip_suffix(".alethe").unwrap(); let path = path.replace(|c: char| !c.is_ascii_alphanumeric() && c != '_', "_"); syn::Ident::new(&format!("{}_{}", func_ident, path), func_ident.span()) };
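To close, an end-to-end sketch of the parallel checking API introduced in this change set, mirroring `run_parallel_checker_test` from `carcara/tests/test_example_files.rs` above; the function name, thread count, and stack size are illustrative:

use carcara::{checker, parser, CarcaraResult};
use std::{fs, io, path::Path, sync::Arc};

fn check_in_parallel(problem_path: &Path, proof_path: &Path) -> CarcaraResult<bool> {
    let (prelude, proof, pool) = parser::parse_instance(
        io::BufReader::new(fs::File::open(problem_path)?),
        io::BufReader::new(fs::File::open(proof_path)?),
        parser::Config::new(),
    )?;
    // One schedule is computed per worker thread; `schedule_context_usage` is
    // handed to the checker alongside the schedules.
    let (scheduler, schedule_context_usage) = checker::Scheduler::new(4, &proof);
    let mut checker = checker::ParallelProofChecker::new(
        Arc::new(pool),
        checker::Config::new(),
        &prelude,
        &schedule_context_usage,
        128 * 1024 * 1024, // worker stack size, as in the test above
    );
    checker.check(&proof, &scheduler)
}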