Skip to content

Commit

Permalink
WIP sharing term
Browse files Browse the repository at this point in the history
  • Loading branch information
NotBad4U committed Oct 20, 2024
1 parent 1ab2c39 commit cdf2146
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 43 deletions.
74 changes: 53 additions & 21 deletions carcara/src/lambdapi/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use crate::ast::AnchorArg;
#[allow(const_item_mutation)]
use crate::ast::{
polyeq, Operator, ProblemPrelude, Proof as ProofElaborated, ProofCommand, ProofIter,
ProofStep as AstProofStep, Rc, Sort, Subproof, Term as AletheTerm,
polyeq, pool::{self, TermPool}, AnchorArg, Binder, Operator, PrimitivePool, ProblemPrelude, Proof as ProofElaborated, ProofCommand, ProofIter, ProofStep as AstProofStep, Rc, Sort, Subproof, Term as AletheTerm
};
use indexmap::IndexMap;
use itertools::Itertools;
Expand All @@ -12,7 +9,9 @@ use try_match::unwrap_match;
use std::{
collections::{HashMap, HashSet, VecDeque},
fmt::{self},
ops::Deref,
time::Duration,
vec,
};

mod dsl;
Expand Down Expand Up @@ -47,19 +46,24 @@ pub struct Context {
/// with the `:named` annotation. This feature make step more compact and easier to debug.
/// We do not propose an option to disable this feature because it is enough to run Carcara translation
/// by providing a proof file without `:named` annotation.
term_indices: IndexMap<Rc<AletheTerm>, usize>,
term_sharing: IndexMap<Rc<AletheTerm>, Rc<Term>>,
pub term_indices: IndexMap<Rc<AletheTerm>, (usize, String)>,
pub term_sharing: IndexMap<Rc<AletheTerm>, (String, Term)>,

/// Dependencies of premises as a map Index ↦ location, depth, [Index] where Index represent
/// the location of the premise in the proof.
deps: HashMap<String, (usize, usize, HashSet<usize>)>,
index: usize,
pub global_variables: HashSet<Rc<AletheTerm>>,
pub pool: PrimitivePool,
}

impl Context {
/// Convert dagify subexpression into `Term::TermId` otherwise just apply a canonical conversion
fn get_or_convert(&self, term: &Rc<AletheTerm>) -> Term {
Term::from(term)
term::conv(term, self)
// self.term_sharing
// .get(term)
// .map_or(Term::from(term), |(name, _def)| Term::from(name))
}
}

Expand Down Expand Up @@ -150,26 +154,47 @@ fn gen_required_module() -> Vec<Command> {
]
}

fn gen_shared_term(ctx: &Context) -> Vec<Command> {
ctx.term_indices
.iter()
.filter(|(_, (counter, _))| *counter >= 2)
.map(|(t, (..))| ctx.term_sharing[t].clone())
.map(|(id, term)| Command::Definition(id.to_string(), vec![], None, Some(term)))
.collect_vec()
}

pub fn produce_lambdapi_proof<'a>(
prelude: ProblemPrelude,
proof_elaborated: ProofElaborated,
mut pool: pool::PrimitivePool,
) -> TradResult<ProofFile> {
let mut proof_file = ProofFile::new();

proof_file.requires = gen_required_module();

let global_variables: HashSet<_> =
prelude
.function_declarations
.iter()
.map(|var| pool.add(var.clone().into()))
.collect();

proof_file.definitions = translate_prelude(prelude);

let mut context = Context::default();

context.global_variables = global_variables;

let commands = translate_commands(
&mut context,
&mut proof_elaborated.iter(),
0,
|id, t, ps| Command::Symbol(None, normalize_name(id), vec![], t, Some(Proof(ps))),
)?;

println!("{:#?}", context.term_indices);
let shared_terms = gen_shared_term(&context);

proof_file.definitions.extend(shared_terms);

proof_file.content.extend(commands);

Expand Down Expand Up @@ -533,7 +558,10 @@ fn translate_subproof<'a>(
ProofCommand::Step(AstProofStep { id, clause, rule,.. }) => (normalize_name(id), clause, rule)
);

let clause = clause.iter().map(From::from).collect_vec();
let clause = clause
.iter()
.map(|t| context.get_or_convert(t))
.collect_vec();

let mut fresh_ctx = Context::default();
fresh_ctx.deps = context.deps.clone();
Expand Down Expand Up @@ -566,16 +594,16 @@ fn translate_subproof<'a>(

let last_step_id = unwrap_match!(commands.get(commands.len() - 2), Some(ProofCommand::Step(AstProofStep{id, ..})) => normalize_name(id));

let bind_lemma = match clause.first() {
Some(Term::Alethe(LTerm::Eq(l, r)))
if matches!(**l, Term::Alethe(LTerm::Forall(_, _)))
&& matches!(**r, Term::Alethe(LTerm::Forall(_, _))) =>
let bind_lemma = match subproof.clause().first().expect("clause is empty").deref() {
AletheTerm::Op(Operator::Equals, args)
if matches!(args[0].deref(), AletheTerm::Binder(Binder::Forall, _, _))
&& matches!(args[1].deref(), AletheTerm::Binder(Binder::Forall, _, _)) =>
{
"bind_∀"
}
Some(Term::Alethe(LTerm::Eq(l, r)))
if matches!(**l, Term::Alethe(LTerm::Exist(_, _)))
&& matches!(**r, Term::Alethe(LTerm::Exist(_, _))) =>
AletheTerm::Op(Operator::Equals, args)
if matches!(args[0].deref(), AletheTerm::Binder(Binder::Exists, _, _))
&& matches!(args[1].deref(), AletheTerm::Binder(Binder::Exists, _, _)) =>
{
"bind_∃"
}
Expand Down Expand Up @@ -752,14 +780,18 @@ where
let clause = command.clause();
clause
.into_iter()
.for_each(|c| c.visit(&mut ctx.term_indices));
.for_each(|c| c.visit(ctx));

match command {
ProofCommand::Assume { id, term } => {
ctx.deps
.insert(normalize_name(&id), (ctx.index, depth, HashSet::new()));

proof_steps.push(f(id.into(), term::clauses(vec![Term::from(term)]), admit()))
proof_steps.push(f(
id.into(),
term::clauses(vec![ctx.get_or_convert(term)]),
admit(),
))
}
ProofCommand::Step(AstProofStep {
id,
Expand All @@ -785,7 +817,7 @@ where
let proof = translate_resolution(proof_iter, premises, args)?;

let clauses = Term::Alethe(LTerm::Proof(Box::new(Term::Alethe(LTerm::Clauses(
clause.into_iter().map(|a| Term::from(a)).collect(),
clause.into_iter().map(|a| ctx.get_or_convert(a)).collect(),
)))));

proof_steps.push(f(normalize_name(id), clauses, proof));
Expand All @@ -796,7 +828,7 @@ where
ctx.deps
.insert(normalize_name(&id), (ctx.index, depth, HashSet::new()));

let terms: Vec<Term> = clause.into_iter().map(|a| Term::from(a)).collect();
let terms: Vec<Term> = clause.into_iter().map(|a| ctx.get_or_convert(a)).collect();

let proof_script = translate_rare_simp(args);

Expand All @@ -811,7 +843,7 @@ where
ProofCommand::Step(AstProofStep { id, clause, rule, .. }) if rule.contains("simp") => {
ctx.deps
.insert(normalize_name(&id), (ctx.index, depth, HashSet::new()));
let terms: Vec<Term> = clause.into_iter().map(|a| Term::from(a)).collect();
let terms: Vec<Term> = clause.into_iter().map(|a| ctx.get_or_convert(a)).collect();

let proof_script = translate_simplify_step(rule);

Expand Down
172 changes: 150 additions & 22 deletions carcara/src/lambdapi/term.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::ast::{
Binder as AletheBinder, BindingList, Constant, Operator, ParamOperator, ProofStep, Rc, Sort,
SortedVar, Term as AletheTerm,
pool, Binder as AletheBinder, BindingList, Constant, Operator, Rc, Sort, SortedVar,
Term as AletheTerm, TermPool,
};
use indexmap::IndexMap;
use itertools::Itertools;
Expand All @@ -12,6 +12,7 @@ use std::{fmt, usize, vec};
const WHITE_SPACE: &'static str = " ";

use super::proof::Proof;
use super::Context;

/// The BNF grammar of Lambdapi is in [lambdapi.bnf](https://raw.githubusercontent.com/Deducteam/lambdapi/master/doc/lambdapi.bnf).
/// Data structure of this file try to represent this grammar.
Expand Down Expand Up @@ -265,6 +266,92 @@ impl<S: Into<String>> From<S> for Term {
}
}

pub fn conv(term: &Rc<AletheTerm>, ctx: &crate::lambdapi::Context) -> Term {
ctx.term_sharing.get(term).map_or_else(
|| match term.deref() {
AletheTerm::Sort(_) => Term::from(term),
AletheTerm::App(f, args) => {
let mut func = vec![conv(f, ctx)];
let mut args: Vec<Term> = args.into_iter().map(|a| conv(a, ctx)).collect();
func.append(&mut args);
Term::Terms(func)
}
AletheTerm::Op(operator, args) => {
let args = args
.into_iter()
.map(|a| conv(a, ctx))
.collect::<VecDeque<_>>();
return match operator {
Operator::Not => Term::Alethe(LTerm::Neg(Some(Box::new(
args.front().map(|a| a.clone()).unwrap(),
)))),
Operator::Or => Term::Alethe(LTerm::NOr(args.into())),
Operator::Equals => Term::Alethe(LTerm::Eq(
Box::new(args[0].clone()),
Box::new(args[1].clone()),
)),
Operator::And => Term::Alethe(LTerm::NAnd(args.into())),
Operator::Implies => Term::Alethe(LTerm::Implies(
Box::new(args[0].clone()),
Box::new(args[1].clone()),
)),
Operator::Distinct => Term::Alethe(LTerm::Distinct(ListLP(
args.into_iter().map(Into::into).collect_vec(),
))),
Operator::Sub if args.len() == 2 => {
Term::Terms(vec![args[0].clone(), "-".into(), args[1].clone()])
}
Operator::Sub if args.len() == 1 => {
Term::Terms(vec!["~".into(), args[0].clone()])
}
Operator::Add => {
Term::Terms(vec![args[0].clone(), "+".into(), args[1].clone()])
}
Operator::GreaterEq => {
Term::Terms(vec![args[0].clone(), "≥".into(), args[1].clone()])
}
Operator::GreaterThan => {
Term::Terms(vec![args[0].clone(), ">".into(), args[1].clone()])
}
Operator::LessEq => {
Term::Terms(vec![args[0].clone(), "≤".into(), args[1].clone()])
}
Operator::LessThan => {
Term::Terms(vec![args[0].clone(), "<".into(), args[1].clone()])
}
Operator::Mult => {
Term::Terms(vec![args[0].clone(), "×".into(), args[1].clone()])
}
Operator::RareList => {
Term::Terms(args.into_iter().map(From::from).collect_vec())
}
Operator::True => Term::Alethe(LTerm::True),
Operator::False => Term::Alethe(LTerm::False),
o => todo!("Operator {:?}", o),
};
}
AletheTerm::Let(..) => todo!("let term"),
AletheTerm::Binder(AletheBinder::Forall, bs, t) => {
Term::Alethe(LTerm::Forall(Bindings::from(bs), Box::new(conv(t, ctx))))
}
AletheTerm::Binder(AletheBinder::Exists, bs, t) => {
Term::Alethe(LTerm::Exist(Bindings::from(bs), Box::new(conv(t, ctx))))
}
AletheTerm::Binder(AletheBinder::Choice, bs, t) => {
Term::Alethe(LTerm::Choice(Bindings::from(bs), Box::new(conv(t, ctx))))
}
AletheTerm::Var(id, _term) => Term::TermId(id.to_string()),
AletheTerm::Const(c) => match c {
Constant::Integer(i) => Term::Nat(i.to_u32().unwrap()), //FIXME: better support of number
Constant::String(s) => Term::from(s),
c => unimplemented!("{}", c),
},
e => todo!("{:#?}", e),
},
|(name, _def)| Term::from(name),
)
}

impl From<&Rc<AletheTerm>> for Term {
fn from(term: &Rc<AletheTerm>) -> Self {
match term.deref() {
Expand Down Expand Up @@ -543,41 +630,82 @@ pub fn clauses(terms: Vec<Term>) -> Term {
}

pub trait Visitor {
fn visit(&self, map: &mut IndexMap<Rc<AletheTerm>, usize>);
fn visit(&self, ctx: &mut Context);
}

impl Visitor for Rc<AletheTerm> {
fn visit(&self, map: &mut IndexMap<Rc<AletheTerm>, usize>) {
fn visit(&self, ctx: &mut Context) {
match self.deref() {
AletheTerm::Const(_)
| AletheTerm::Var(..)
| AletheTerm::Sort(_)
| AletheTerm::ParamOp { .. }
| AletheTerm::Let(..) => {}
| AletheTerm::Let(..)
| AletheTerm::Op(Operator::True, _)
| AletheTerm::Op(Operator::False, _) => {}
AletheTerm::Op(_, ops) => {
if let Some(count) = map.get_mut(self) {
*count = *count + 1;
} else {
map.insert(self.clone(), 1);
if self.is_closed(&mut ctx.pool, &ctx.global_variables) {
if let Some((count, t)) = ctx.term_indices.get_mut(self) {
*count = *count + 1;
if *count >= 1 {
ctx.term_sharing
.insert(self.clone(), (t.to_string(), self.into()));
}
} else {
ctx.term_indices.insert(
self.clone(),
(1, format!("p_{}", ctx.term_indices.len() + 1)),
);
}
}
ops.into_iter().for_each(|op| op.visit(map));
ops.into_iter().for_each(|op| op.visit(ctx));
}
AletheTerm::App(o, ops) => {
if let Some(count) = map.get_mut(self) {
*count = *count + 1;
} else {
map.insert(self.clone(), 1);
if self.is_closed(&mut ctx.pool, &ctx.global_variables) {
if let Some((count, t)) = ctx.term_indices.get_mut(self) {
*count = *count + 1;
if *count >= 1 {
ctx.term_sharing
.insert(self.clone(), (t.to_string(), self.into()));
}
} else {
ctx.term_indices.insert(
self.clone(),
(1, format!("p_{}", ctx.term_indices.len() + 1)),
);
}
}
o.visit(map);
ops.into_iter().for_each(|op| op.visit(map));
o.visit(ctx);
ops.into_iter().for_each(|op| op.visit(ctx));
}
AletheTerm::Binder(_, _, t) => {
if let Some(count) = map.get_mut(self) {
*count = *count + 1;
} else {
map.insert(self.clone(), 1);
AletheTerm::Binder(_, bs, t) => {
let bs_bindinds = bs.into_iter().map(|(name, _)| name).collect_vec();
let free_vars = ctx
.pool
.free_vars(self)
.into_iter()
.filter(|var| !ctx.global_variables.contains(var))
.filter(|var| match var.deref() {
AletheTerm::Var(var, _) => bs_bindinds.contains(&var) == false,
_ => false,
})
.collect_vec();

if free_vars.is_empty() {
if let Some((count, t)) = ctx.term_indices.get_mut(self) {
*count = *count + 1;
if *count >= 1 {
ctx.term_sharing
.insert(self.clone(), (t.to_string(), self.into()));
}
} else {
ctx.term_indices.insert(
self.clone(),
(1, format!("p_{}", ctx.term_indices.len() + 1)),
);
}
}
t.visit(map);
t.visit(ctx);
}
}
}
Expand Down
Loading

0 comments on commit cdf2146

Please sign in to comment.