diff --git a/carcara/src/ast/pool/advanced.rs b/carcara/src/ast/pool/advanced.rs index d50d5b2d..6eb8b9cf 100644 --- a/carcara/src/ast/pool/advanced.rs +++ b/carcara/src/ast/pool/advanced.rs @@ -47,23 +47,14 @@ impl TermPool for ContextPool { } fn add(&mut self, term: Term) -> Rc { - use std::collections::hash_map::Entry; - // If the global pool has the term if let Some(entry) = self.global_pool.terms.get(&term) { - entry.clone() - } else { - let mut ctx_guard = self.storage.write().unwrap(); - match ctx_guard.terms.entry(term) { - Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), - Entry::Vacant(vacant_entry) => { - let term = vacant_entry.key().clone(); - let t = vacant_entry.insert(Rc::new(term)).clone(); - ctx_guard.compute_sort(&t); - t - } - } + return entry.clone(); } + let mut ctx_guard = self.storage.write().unwrap(); + let term = ctx_guard.terms.add(term); + ctx_guard.compute_sort(&term); + term } fn sort(&self, term: &Rc) -> Rc { @@ -125,8 +116,6 @@ impl TermPool for LocalPool { } fn add(&mut self, term: Term) -> Rc { - use std::collections::hash_map::Entry; - // If there is a constant pool and has the term if let Some(entry) = self.ctx_pool.global_pool.terms.get(&term) { entry.clone() @@ -135,15 +124,7 @@ impl TermPool for LocalPool { else if let Some(entry) = self.ctx_pool.storage.read().unwrap().terms.get(&term) { entry.clone() } else { - match self.storage.terms.entry(term) { - Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), - Entry::Vacant(vacant_entry) => { - let term = vacant_entry.key().clone(); - let t = vacant_entry.insert(Rc::new(term)).clone(); - self.storage.compute_sort(&t); - t - } - } + self.storage.add(term) } } diff --git a/carcara/src/ast/pool/mod.rs b/carcara/src/ast/pool/mod.rs index e2f42948..d36a74af 100644 --- a/carcara/src/ast/pool/mod.rs +++ b/carcara/src/ast/pool/mod.rs @@ -1,10 +1,12 @@ //! This module implements `TermPool`, a structure that stores terms and implements hash consing. pub mod advanced; +mod storage; use super::{Rc, Sort, Term}; use crate::ast::Constant; use ahash::{AHashMap, AHashSet}; +use storage::Storage; pub trait TermPool { /// Returns the term corresponding to the boolean constant `true`. @@ -51,7 +53,7 @@ pub trait TermPool { /// [`PrimitivePool::sort`]) or its free variables (see [`PrimitivePool::free_vars`]). pub struct PrimitivePool { /// A map of the terms in the pool. - pub(crate) terms: AHashMap>, + pub(crate) terms: Storage, pub(crate) free_vars_cache: AHashMap, AHashSet>>, pub(crate) sorts_cache: AHashMap, Rc>, pub(crate) bool_true: Rc, @@ -68,12 +70,12 @@ impl PrimitivePool { /// Constructs a new `TermPool`. This new pool will already contain the boolean constants `true` /// and `false`, as well as the `Bool` sort. pub fn new() -> Self { - let mut terms = AHashMap::new(); + let mut terms = Storage::new(); let mut sorts_cache = AHashMap::new(); - let bool_sort = Self::add_term_to_map(&mut terms, Term::Sort(Sort::Bool)); + let bool_sort = terms.add(Term::Sort(Sort::Bool)); - let [bool_true, bool_false] = ["true", "false"] - .map(|b| Self::add_term_to_map(&mut terms, Term::new_var(b, bool_sort.clone()))); + let [bool_true, bool_false] = + ["true", "false"].map(|b| terms.add(Term::new_var(b, bool_sort.clone()))); sorts_cache.insert(bool_false.clone(), bool_sort.clone()); sorts_cache.insert(bool_true.clone(), bool_sort.clone()); @@ -88,18 +90,6 @@ impl PrimitivePool { } } - fn add_term_to_map(terms_map: &mut AHashMap>, term: Term) -> Rc { - use std::collections::hash_map::Entry; - - match terms_map.entry(term) { - Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), - Entry::Vacant(vacant_entry) => { - let term = vacant_entry.key().clone(); - vacant_entry.insert(Rc::new(term)).clone() - } - } - } - /// Computes the sort of a term and adds it to the sort cache. fn compute_sort(&mut self, term: &Rc) -> Rc { use super::Operator; @@ -164,8 +154,8 @@ impl PrimitivePool { Sort::Function(result) } }; - let sorted_term = Self::add_term_to_map(&mut self.terms, Term::Sort(result)); - self.sorts_cache.insert(term.clone(), sorted_term); + let sort = self.terms.add(Term::Sort(result)); + self.sorts_cache.insert(term.clone(), sort); self.sorts_cache[term].clone() } @@ -174,8 +164,6 @@ impl PrimitivePool { term: Term, prior_pools: [&PrimitivePool; N], ) -> Rc { - use std::collections::hash_map::Entry; - for p in prior_pools { // If this prior pool has the term if let Some(entry) = p.terms.get(&term) { @@ -183,15 +171,7 @@ impl PrimitivePool { } } - match self.terms.entry(term) { - Entry::Occupied(occupied_entry) => occupied_entry.get().clone(), - Entry::Vacant(vacant_entry) => { - let term = vacant_entry.key().clone(); - let term = vacant_entry.insert(Rc::new(term)).clone(); - self.compute_sort(&term); - term - } - } + self.add(term) } fn sort_with_priorities( @@ -283,7 +263,7 @@ impl TermPool for PrimitivePool { } fn add(&mut self, term: Term) -> Rc { - let term = Self::add_term_to_map(&mut self.terms, term); + let term = self.terms.add(term); self.compute_sort(&term); term } diff --git a/carcara/src/ast/pool/storage.rs b/carcara/src/ast/pool/storage.rs new file mode 100644 index 00000000..fb611e51 --- /dev/null +++ b/carcara/src/ast/pool/storage.rs @@ -0,0 +1,66 @@ +//* The behaviour of the term pool could be modeled by a hash map from `Term` to `Rc`, but +//* that would require allocating two copies of each term, one in the key of the hash map, and one +//* inside the `Rc`. Instead, we store a hash set of `Rc`s, combining the key and the value +//* into a single object. We access this hash set using a `&Term`, and if the entry is present, we +//* clone it; otherwise, we allocate a new `Rc`. + +use crate::ast::*; +use std::borrow::Borrow; + +/// Since `ast::Rc` intentionally implements hashing and equality by reference (instead of by +/// value), we cannot safely implement `Borrow` for `Rc`, so we cannot access a +/// `HashSet>` using a `&Term` as a key. To go around that, we use this struct that wraps +/// an `Rc` and that re-implements hashing and equality by value, meaning we can implement +/// `Borrow` for it, and use it as the contents of the hash set instead. +#[derive(Debug, Clone, Eq)] +struct ByValue(Rc); + +impl PartialEq for ByValue { + fn eq(&self, other: &Self) -> bool { + self.0.as_ref() == other.0.as_ref() + } +} + +impl Hash for ByValue { + fn hash(&self, state: &mut H) { + self.0.as_ref().hash(state); + } +} + +impl Borrow for ByValue { + fn borrow(&self) -> &Term { + self.0.as_ref() + } +} + +#[derive(Debug, Clone, Default)] +pub struct Storage(AHashSet); + +impl Storage { + pub fn new() -> Self { + Self::default() + } + + pub fn add(&mut self, term: Term) -> Rc { + // If the `hash_set_entry` feature was stable, this would be much simpler to do using + // `get_or_insert_with` (and would avoid rehashing the term) + match self.0.get(&term) { + Some(t) => t.0.clone(), + None => { + let result = Rc::new(term); + self.0.insert(ByValue(result.clone())); + result + } + } + } + + pub fn get(&self, term: &Term) -> Option<&Rc> { + self.0.get(term).map(|t| &t.0) + } + + // This method is only necessary for the hash consing tests + #[cfg(test)] + pub fn into_vec(self) -> Vec> { + self.0.into_iter().map(|ByValue(t)| t).collect() + } +} diff --git a/carcara/src/parser/tests.rs b/carcara/src/parser/tests.rs index a5ace572..0f493dbc 100644 --- a/carcara/src/parser/tests.rs +++ b/carcara/src/parser/tests.rs @@ -94,9 +94,9 @@ fn test_hash_consing() { .into_iter() .collect::>(); - assert_eq!(pool.terms.len(), expected.len()); - - for got in pool.terms.keys() { + let pool_terms = pool.terms.into_vec(); + assert_eq!(pool_terms.len(), expected.len()); + for got in pool_terms { let formatted: &str = &format!("{}", got); assert!(expected.contains(formatted), "{}", formatted); }