Skip to content

Commit

Permalink
Prevent reaching unreachable markers in AstCache
Browse files Browse the repository at this point in the history
  • Loading branch information
twizmwazin committed Nov 4, 2024
1 parent 95d336f commit 3f1f1ee
Showing 1 changed file with 79 additions and 53 deletions.
132 changes: 79 additions & 53 deletions crates/clarirs_core/src/ast/astcache.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::{Arc, RwLock, Weak};
use std::sync::{RwLock, Weak};

use ahash::HashMap;

Expand All @@ -12,6 +12,60 @@ enum AstCacheValue<'c> {
String(Weak<AstNode<'c, StringOp<'c>>>),
}

impl<'c> AstCacheValue<'c> {
fn as_bool(&self) -> Option<BoolAst<'c>> {
match self {
AstCacheValue::Boolean(weak) => weak.upgrade(),
_ => None,
}
}

fn as_bv(&self) -> Option<BitVecAst<'c>> {
match self {
AstCacheValue::BitVec(weak) => weak.upgrade(),
_ => None,
}
}

fn as_float(&self) -> Option<FloatAst<'c>> {
match self {
AstCacheValue::Float(weak) => weak.upgrade(),
_ => None,
}
}

fn as_string(&self) -> Option<StringAst<'c>> {
match self {
AstCacheValue::String(weak) => weak.upgrade(),
_ => None,
}
}
}

impl<'c> From<BoolAst<'c>> for AstCacheValue<'c> {
fn from(ast: BoolAst<'c>) -> Self {
AstCacheValue::Boolean(BoolAst::downgrade(&ast))
}
}

impl<'c> From<BitVecAst<'c>> for AstCacheValue<'c> {
fn from(ast: BitVecAst<'c>) -> Self {
AstCacheValue::BitVec(BitVecAst::downgrade(&ast))
}
}

impl<'c> From<FloatAst<'c>> for AstCacheValue<'c> {
fn from(ast: FloatAst<'c>) -> Self {
AstCacheValue::Float(FloatAst::downgrade(&ast))
}
}

impl<'c> From<StringAst<'c>> for AstCacheValue<'c> {
fn from(ast: StringAst<'c>) -> Self {
AstCacheValue::String(StringAst::downgrade(&ast))
}
}

#[derive(Debug, Default)]
pub struct AstCache<'c> {
inner: RwLock<HashMap<u64, AstCacheValue<'c>>>,
Expand All @@ -24,20 +78,13 @@ impl<'c> AstCache<'c> {
f: F,
) -> BoolAst<'c> {
let mut inner = self.inner.write().unwrap();
let entry = inner
.entry(hash)
.or_insert_with(|| AstCacheValue::Boolean(Weak::new()));
match entry {
AstCacheValue::Boolean(weak) => {
if let Some(arc) = weak.upgrade() {
arc
} else {
let arc = f();
*entry = AstCacheValue::Boolean(Arc::downgrade(&arc));
arc
}
match inner.get(&hash).and_then(|v| v.as_bool()) {
Some(value) => value,
None => {
let this = f();
inner.insert(hash, this.clone().into());
this
}
_ => unreachable!(),
}
}

Expand All @@ -47,20 +94,13 @@ impl<'c> AstCache<'c> {
f: F,
) -> BitVecAst<'c> {
let mut inner = self.inner.write().unwrap();
let entry = inner
.entry(hash)
.or_insert_with(|| AstCacheValue::BitVec(Weak::new()));
match entry {
AstCacheValue::BitVec(weak) => {
if let Some(arc) = weak.upgrade() {
arc
} else {
let arc = f();
*entry = AstCacheValue::BitVec(Arc::downgrade(&arc));
arc
}
match inner.get(&hash).and_then(|v| v.as_bv()) {
Some(value) => value,
None => {
let this = f();
inner.insert(hash, this.clone().into());
this
}
_ => unreachable!(),
}
}

Expand All @@ -70,20 +110,13 @@ impl<'c> AstCache<'c> {
f: F,
) -> FloatAst<'c> {
let mut inner = self.inner.write().unwrap();
let entry = inner
.entry(hash)
.or_insert_with(|| AstCacheValue::Float(Weak::new()));
match entry {
AstCacheValue::Float(weak) => {
if let Some(arc) = weak.upgrade() {
arc
} else {
let arc = f();
*entry = AstCacheValue::Float(Arc::downgrade(&arc));
arc
}
match inner.get(&hash).and_then(|v| v.as_float()) {
Some(value) => value,
None => {
let this = f();
inner.insert(hash, this.clone().into());
this
}
_ => unreachable!(),
}
}

Expand All @@ -93,20 +126,13 @@ impl<'c> AstCache<'c> {
f: F,
) -> StringAst<'c> {
let mut inner = self.inner.write().unwrap();
let entry = inner
.entry(hash)
.or_insert_with(|| AstCacheValue::String(Weak::new()));
match entry {
AstCacheValue::String(weak) => {
if let Some(arc) = weak.upgrade() {
arc
} else {
let arc = f();
*entry = AstCacheValue::String(Arc::downgrade(&arc));
arc
}
match inner.get(&hash).and_then(|v| v.as_string()) {
Some(value) => value,
None => {
let this = f();
inner.insert(hash, this.clone().into());
this
}
_ => unreachable!(),
}
}
}

0 comments on commit 3f1f1ee

Please sign in to comment.