Skip to content

Commit

Permalink
Avoid holding two copies of each term in the term pool
Browse files Browse the repository at this point in the history
Now, instead of storing terms with a `HashMap<Term, Rc<Term>>`, the pool
uses a hash set.
  • Loading branch information
bpandreotti committed Aug 11, 2023
1 parent 271152c commit 4d2a059
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 59 deletions.
31 changes: 6 additions & 25 deletions carcara/src/ast/pool/advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,14 @@ impl TermPool for ContextPool {
}

fn add(&mut self, term: Term) -> Rc<Term> {
use std::collections::hash_map::Entry;

// If the global pool has the term
if let Some(entry) = self.global_pool.terms.get(&term) {
entry.clone()
} else {
let mut ctx_guard = self.storage.write().unwrap();
match ctx_guard.terms.entry(term) {
Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
Entry::Vacant(vacant_entry) => {
let term = vacant_entry.key().clone();
let t = vacant_entry.insert(Rc::new(term)).clone();
ctx_guard.compute_sort(&t);
t
}
}
return entry.clone();
}
let mut ctx_guard = self.storage.write().unwrap();
let term = ctx_guard.terms.add(term);
ctx_guard.compute_sort(&term);
term
}

fn sort(&self, term: &Rc<Term>) -> Rc<Term> {
Expand Down Expand Up @@ -125,8 +116,6 @@ impl TermPool for LocalPool {
}

fn add(&mut self, term: Term) -> Rc<Term> {
use std::collections::hash_map::Entry;

// If there is a constant pool and has the term
if let Some(entry) = self.ctx_pool.global_pool.terms.get(&term) {
entry.clone()
Expand All @@ -135,15 +124,7 @@ impl TermPool for LocalPool {
else if let Some(entry) = self.ctx_pool.storage.read().unwrap().terms.get(&term) {
entry.clone()
} else {
match self.storage.terms.entry(term) {
Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
Entry::Vacant(vacant_entry) => {
let term = vacant_entry.key().clone();
let t = vacant_entry.insert(Rc::new(term)).clone();
self.storage.compute_sort(&t);
t
}
}
self.storage.add(term)
}
}

Expand Down
42 changes: 11 additions & 31 deletions carcara/src/ast/pool/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
//! This module implements `TermPool`, a structure that stores terms and implements hash consing.

pub mod advanced;
mod storage;

use super::{Rc, Sort, Term};
use crate::ast::Constant;
use ahash::{AHashMap, AHashSet};
use storage::Storage;

pub trait TermPool {
/// Returns the term corresponding to the boolean constant `true`.
Expand Down Expand Up @@ -51,7 +53,7 @@ pub trait TermPool {
/// [`PrimitivePool::sort`]) or its free variables (see [`PrimitivePool::free_vars`]).
pub struct PrimitivePool {
/// A map of the terms in the pool.
pub(crate) terms: AHashMap<Term, Rc<Term>>,
pub(crate) terms: Storage,
pub(crate) free_vars_cache: AHashMap<Rc<Term>, AHashSet<Rc<Term>>>,
pub(crate) sorts_cache: AHashMap<Rc<Term>, Rc<Term>>,
pub(crate) bool_true: Rc<Term>,
Expand All @@ -68,12 +70,12 @@ impl PrimitivePool {
/// Constructs a new `TermPool`. This new pool will already contain the boolean constants `true`
/// and `false`, as well as the `Bool` sort.
pub fn new() -> Self {
let mut terms = AHashMap::new();
let mut terms = Storage::new();
let mut sorts_cache = AHashMap::new();
let bool_sort = Self::add_term_to_map(&mut terms, Term::Sort(Sort::Bool));
let bool_sort = terms.add(Term::Sort(Sort::Bool));

let [bool_true, bool_false] = ["true", "false"]
.map(|b| Self::add_term_to_map(&mut terms, Term::new_var(b, bool_sort.clone())));
let [bool_true, bool_false] =
["true", "false"].map(|b| terms.add(Term::new_var(b, bool_sort.clone())));

sorts_cache.insert(bool_false.clone(), bool_sort.clone());
sorts_cache.insert(bool_true.clone(), bool_sort.clone());
Expand All @@ -88,18 +90,6 @@ impl PrimitivePool {
}
}

fn add_term_to_map(terms_map: &mut AHashMap<Term, Rc<Term>>, term: Term) -> Rc<Term> {
use std::collections::hash_map::Entry;

match terms_map.entry(term) {
Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
Entry::Vacant(vacant_entry) => {
let term = vacant_entry.key().clone();
vacant_entry.insert(Rc::new(term)).clone()
}
}
}

/// Computes the sort of a term and adds it to the sort cache.
fn compute_sort(&mut self, term: &Rc<Term>) -> Rc<Term> {
use super::Operator;
Expand Down Expand Up @@ -164,8 +154,8 @@ impl PrimitivePool {
Sort::Function(result)
}
};
let sorted_term = Self::add_term_to_map(&mut self.terms, Term::Sort(result));
self.sorts_cache.insert(term.clone(), sorted_term);
let sort = self.terms.add(Term::Sort(result));
self.sorts_cache.insert(term.clone(), sort);
self.sorts_cache[term].clone()
}

Expand All @@ -174,24 +164,14 @@ impl PrimitivePool {
term: Term,
prior_pools: [&PrimitivePool; N],
) -> Rc<Term> {
use std::collections::hash_map::Entry;

for p in prior_pools {
// If this prior pool has the term
if let Some(entry) = p.terms.get(&term) {
return entry.clone();
}
}

match self.terms.entry(term) {
Entry::Occupied(occupied_entry) => occupied_entry.get().clone(),
Entry::Vacant(vacant_entry) => {
let term = vacant_entry.key().clone();
let term = vacant_entry.insert(Rc::new(term)).clone();
self.compute_sort(&term);
term
}
}
self.add(term)
}

fn sort_with_priorities<const N: usize>(
Expand Down Expand Up @@ -283,7 +263,7 @@ impl TermPool for PrimitivePool {
}

fn add(&mut self, term: Term) -> Rc<Term> {
let term = Self::add_term_to_map(&mut self.terms, term);
let term = self.terms.add(term);
self.compute_sort(&term);
term
}
Expand Down
66 changes: 66 additions & 0 deletions carcara/src/ast/pool/storage.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//* The behaviour of the term pool could be modeled by a hash map from `Term` to `Rc<Term>`, but
//* that would require allocating two copies of each term, one in the key of the hash map, and one
//* inside the `Rc`. Instead, we store a hash set of `Rc<Term>`s, combining the key and the value
//* into a single object. We access this hash set using a `&Term`, and if the entry is present, we
//* clone it; otherwise, we allocate a new `Rc`.

use crate::ast::*;
use std::borrow::Borrow;

/// A wrapper around `Rc<Term>` that hashes and compares by value rather than by reference.
///
/// Since `ast::Rc` intentionally implements hashing and equality by reference (instead of by
/// value), we cannot safely implement `Borrow<Term>` for `Rc<Term>`, so we cannot access a
/// `HashSet<Rc<Term>>` using a `&Term` as a key. To work around that, we use this struct, which
/// wraps an `Rc<Term>` and re-implements hashing and equality by value. That lets us implement
/// `Borrow<Term>` for it and use it as the element type of the hash set instead.
///
/// `PartialEq` and `Hash` are written by hand below and both delegate to the wrapped `Term`,
/// which is what the `Borrow` contract requires for lookups by `&Term` to be consistent. `Eq` is
/// derived, as it only marks the manual equality as a full equivalence relation.
#[derive(Debug, Clone, Eq)]
struct ByValue(Rc<Term>);

impl PartialEq for ByValue {
fn eq(&self, other: &Self) -> bool {
self.0.as_ref() == other.0.as_ref()
}
}

impl Hash for ByValue {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.0.as_ref().hash(state);
}
}

impl Borrow<Term> for ByValue {
fn borrow(&self) -> &Term {
self.0.as_ref()
}
}

/// The backing store of a term pool: a hash set of `Rc<Term>`s, each wrapped in [`ByValue`] so
/// that entries are hashed and compared by the term's value and can be looked up with a `&Term`.
#[derive(Debug, Clone, Default)]
pub struct Storage(AHashSet<ByValue>);

impl Storage {
pub fn new() -> Self {
Self::default()
}

pub fn add(&mut self, term: Term) -> Rc<Term> {
// If the `hash_set_entry` feature was stable, this would be much simpler to do using
// `get_or_insert_with` (and would avoid rehashing the term)
match self.0.get(&term) {
Some(t) => t.0.clone(),
None => {
let result = Rc::new(term);
self.0.insert(ByValue(result.clone()));
result
}
}
}

pub fn get(&self, term: &Term) -> Option<&Rc<Term>> {
self.0.get(term).map(|t| &t.0)
}

// This method is only necessary for the hash consing tests
#[cfg(test)]
pub fn into_vec(self) -> Vec<Rc<Term>> {
self.0.into_iter().map(|ByValue(t)| t).collect()
}
}
6 changes: 3 additions & 3 deletions carcara/src/parser/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ fn test_hash_consing() {
.into_iter()
.collect::<AHashSet<&str>>();

assert_eq!(pool.terms.len(), expected.len());

for got in pool.terms.keys() {
let pool_terms = pool.terms.into_vec();
assert_eq!(pool_terms.len(), expected.len());
for got in pool_terms {
let formatted: &str = &format!("{}", got);
assert!(expected.contains(formatted), "{}", formatted);
}
Expand Down

0 comments on commit 4d2a059

Please sign in to comment.