Skip to content

Commit

Permalink
Fix match_with_constraint (#4499)
Browse files Browse the repository at this point in the history
Fixes #4496

---------

Co-authored-by: JianhongZhao <[email protected]>
Co-authored-by: Everett Hildenbrandt <[email protected]>
  • Loading branch information
3 people authored Oct 7, 2024
1 parent d40018b commit 8bf17a3
Show file tree
Hide file tree
Showing 7 changed files with 55 additions and 104 deletions.
26 changes: 4 additions & 22 deletions pyk/src/pyk/cterm/cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ def match_with_constraint(self, cterm: CTerm) -> CSubst | None:
if subst is None:
return None

constraint = self._ml_impl(cterm.constraints, map(subst, self.constraints))
source_constraints = [subst(c) for c in self.constraints]
constraints = [c for c in cterm.constraints if c not in source_constraints]

return CSubst(subst=subst, constraints=[constraint])
return CSubst(subst, constraints)

@staticmethod
def _ml_impl(antecedents: Iterable[KInner], consequents: Iterable[KInner]) -> KInner:
Expand Down Expand Up @@ -246,25 +247,6 @@ def remove_useless_constraints(self, keep_vars: Iterable[str] = ()) -> CTerm:
return CTerm(self.config, new_constraints)


def cterm_match(cterm1: CTerm, cterm2: CTerm) -> CSubst | None:
"""Find a substitution which can instantiate `cterm1` to `cterm2`.
Args:
cterm1: `CTerm` to instantiate to `cterm2`.
cterm2: `CTerm` to instantiate `cterm1` to.
Returns:
A `CSubst` which can instantiate `cterm1` to `cterm2`, or `None` if no such `CSubst` exists.
"""
# todo: delete this function and use cterm1.match_with_constraint(cterm2) directly after closing #4496
subst = cterm1.config.match(cterm2.config)
if subst is None:
return None
source_constraints = [subst(c) for c in cterm1.constraints]
constraints = [c for c in cterm2.constraints if c not in source_constraints]
return CSubst(subst, constraints)


def anti_unify(state1: KInner, state2: KInner, kdef: KDefinition | None = None) -> tuple[KInner, Subst, Subst]:
"""Return a generalized state over the two input states.
Expand Down Expand Up @@ -461,5 +443,5 @@ def cterms_anti_unify(
merged_cterm = cterms[0]
for cterm in cterms[1:]:
merged_cterm = merged_cterm.anti_unify(cterm, keep_values, kdef)[0]
csubsts = [not_none(cterm_match(merged_cterm, cterm)) for cterm in cterms]
csubsts = [not_none(merged_cterm.match_with_constraint(cterm)) for cterm in cterms]
return merged_cterm, csubsts
10 changes: 1 addition & 9 deletions pyk/src/pyk/kcfg/kcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,16 @@
from abc import ABC, abstractmethod
from collections.abc import Container
from dataclasses import dataclass, field
from functools import reduce
from threading import RLock
from typing import TYPE_CHECKING, Final, List, Union, cast, final

from pyk.cterm.cterm import cterm_match

from ..cterm import CSubst, CTerm, cterm_build_claim, cterm_build_rule
from ..kast import EMPTY_ATT
from ..kast.inner import KApply
from ..kast.manip import (
bool_to_ml_pred,
extract_lhs,
extract_rhs,
flatten_label,
inline_cell_maps,
minimize_rule_like,
rename_generated_vars,
Expand Down Expand Up @@ -1062,7 +1058,7 @@ def create_split_by_nodes(self, source_id: NodeIdLike, target_ids: Iterable[Node
source = self.node(source_id)
targets = [self.node(nid) for nid in target_ids]
try:
csubsts = [not_none(cterm_match(source.cterm, target.cterm)) for target in targets]
csubsts = [not_none(source.cterm.match_with_constraint(target.cterm)) for target in targets]
except ValueError:
return None
return self.create_split(source.id, zip(target_ids, csubsts, strict=True))
Expand Down Expand Up @@ -1091,10 +1087,6 @@ def split_on_constraints(self, source_id: NodeIdLike, constraints: Iterable[KInn
source = self.node(source_id)
branch_node_ids = [self.create_node(source.cterm.add_constraint(c)).id for c in constraints]
csubsts = [not_none(source.cterm.match_with_constraint(self.node(id).cterm)) for id in branch_node_ids]
csubsts = [
reduce(CSubst.add_constraint, flatten_label('#And', constraint), csubst)
for (csubst, constraint) in zip(csubsts, constraints, strict=True)
]
self.create_split(source.id, zip(branch_node_ids, csubsts, strict=True))
return branch_node_ids

Expand Down
40 changes: 9 additions & 31 deletions pyk/src/pyk/kcfg/minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from functools import reduce
from typing import TYPE_CHECKING

from pyk.cterm import CTerm
from pyk.cterm.cterm import cterms_anti_unify
from pyk.utils import not_none, partition, single
from pyk.utils import partition, single

from .semantics import DefaultSemantics

Expand Down Expand Up @@ -96,26 +95,19 @@ def lift_split_edge(self, b_id: NodeIdLike) -> None:
# Ensure split can be lifted soundly (i.e., that it does not introduce fresh variables)
assert (
len(split_from_b.source_vars.difference(a.free_vars)) == 0
and len(split_from_b.target_vars.difference(split_from_b.source_vars)) == 0
and len(split_from_b.target_vars.difference(split_from_b.source_vars)) == 0 # <-- Can we delete this check?
)
# Create CTerms and CSubsts corresponding to the new targets of the split
new_cterms_with_constraints = [
(CTerm(a.cterm.config, a.cterm.constraints + csubst.constraints), csubst.constraint) for csubst in csubsts
]
# Generate substitutions for new targets, which all exist by construction
new_csubsts = [
not_none(a.cterm.match_with_constraint(cterm)).add_constraint(constraint)
for (cterm, constraint) in new_cterms_with_constraints
]
new_cterms = [csubst(a.cterm) for csubst in csubsts]
# Remove the node `B`, effectively removing the entire initial structure
self.kcfg.remove_node(b_id)
# Create the nodes `[ A #And cond_I | I = 1..N ]`.
ai: list[NodeIdLike] = [self.kcfg.create_node(cterm).id for (cterm, _) in new_cterms_with_constraints]
ai: list[NodeIdLike] = [self.kcfg.create_node(cterm).id for cterm in new_cterms]
# Create the edges `[A #And cond_1 --M steps--> C_I | I = 1..N ]`
for i in range(len(ai)):
self.kcfg.create_edge(ai[i], ci[i], a_to_b.depth, a_to_b.rules)
# Create the split `A --[cond_1, ..., cond_N]--> [A #And cond_1, ..., A #And cond_N]
self.kcfg.create_split(a.id, zip(ai, new_csubsts, strict=True))
self.kcfg.create_split_by_nodes(a.id, ai)

def lift_split_split(self, b_id: NodeIdLike) -> None:
"""Lift a split up a split directly preceding it, joining them into a single split.
Expand All @@ -136,31 +128,17 @@ def lift_split_split(self, b_id: NodeIdLike) -> None:
split_from_a, split_from_b = single(self.kcfg.splits(target_id=b_id)), single(self.kcfg.splits(source_id=b_id))
splits_from_a, splits_from_b = split_from_a.splits, split_from_b.splits
a = split_from_a.source
ci = list(splits_from_b.keys())
list(splits_from_b.keys())
# Ensure split can be lifted soundly (i.e., that it does not introduce fresh variables)
assert (
assert ( # <-- Does it will be a problem when using merging nodes, because it would introduce new variables?
len(split_from_b.source_vars.difference(a.free_vars)) == 0
and len(split_from_b.target_vars.difference(split_from_b.source_vars)) == 0
)
# Get the substitution for `B`, at the same time removing 'B' from the targets of `A`.
csubst_b = splits_from_a.pop(self.kcfg.node(b_id).id)
# Generate substitutions for additional targets `C_I`, which all exist by construction;
# the constraints are cumulative, resulting in `cond_B #And cond_I`
additional_csubsts = [
not_none(a.cterm.match_with_constraint(self.kcfg.node(ci).cterm))
.add_constraint(csubst_b.constraint)
.add_constraint(csubst.constraint)
for ci, csubst in splits_from_b.items()
]
# Create the targets of the new split
ci = list(splits_from_b.keys())
new_splits = zip(
list(splits_from_a.keys()) + ci, list(splits_from_a.values()) + additional_csubsts, strict=True
)
# Remove the node `B`, thereby removing the two splits as well
splits_from_a.pop(self.kcfg.node(b_id).id)
self.kcfg.remove_node(b_id)
# Create the new split `A --[..., cond_B #And cond_1, ..., cond_B #And cond_N, ...]--> [..., C_1, ..., C_N, ...]`
self.kcfg.create_split(a.id, new_splits)
self.kcfg.create_split_by_nodes(a.id, list(splits_from_a.keys()) + list(splits_from_b.keys()))

def lift_splits(self) -> bool:
"""Perform all possible split liftings.
Expand Down
32 changes: 2 additions & 30 deletions pyk/src/tests/integration/proof/test_refute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pyk.prelude.ml import is_top, mlEqualsTrue
from pyk.proof import APRProof, APRProver, ImpliesProver, ProofStatus, RefutationProof
from pyk.testing import KCFGExploreTest, KProveTest
from pyk.utils import not_none, single
from pyk.utils import single

from ..utils import K_FILES

Expand Down Expand Up @@ -327,35 +327,7 @@ def test_apr_proof_refute_node_multiple_constraints(
cfg.create_node(CTerm(config, [l_le_0, m_gt_0]))
cfg.create_node(CTerm(config, [l_le_0, m_le_0]))

proof.kcfg.create_split(
1,
[
(
3,
not_none(cfg.node(1).cterm.match_with_constraint(cfg.node(3).cterm))
.add_constraint(l_gt_0)
.add_constraint(m_gt_0),
),
(
4,
not_none(cfg.node(1).cterm.match_with_constraint(cfg.node(4).cterm))
.add_constraint(l_gt_0)
.add_constraint(m_le_0),
),
(
5,
not_none(cfg.node(1).cterm.match_with_constraint(cfg.node(5).cterm))
.add_constraint(l_le_0)
.add_constraint(m_gt_0),
),
(
6,
not_none(cfg.node(1).cterm.match_with_constraint(cfg.node(6).cterm))
.add_constraint(l_le_0)
.add_constraint(m_gt_0),
),
],
)
proof.kcfg.create_split_by_nodes(1, [3, 4, 5, 6])

# When
prover.advance_proof(proof, max_iterations=4)
Expand Down
16 changes: 8 additions & 8 deletions pyk/src/tests/unit/kcfg/test_minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def test_lifting_functions_automatic() -> None:
[
(node_15, to_csubst_node(x_node(1), node_19, [x_lt_0])),
(node_18, to_csubst_node(x_node(1), node_18, [x_ge_0, x_ge_5])),
(node_19, to_csubst_node(x_node(1), node_19, [x_ge_0, x_lt_5])),
(node_19, to_csubst_node(x_node(1), node_19, [x_lt_5, x_ge_0])),
],
)
)
Expand Down Expand Up @@ -248,7 +248,7 @@ def test_minimize_01() -> None:
[
(node_15, to_csubst_node(x_node(1), node_19, [x_lt_0])),
(node_18, to_csubst_node(x_node(1), node_18, [x_ge_0, x_ge_5])),
(node_19, to_csubst_node(x_node(1), node_19, [x_ge_0, x_lt_5])),
(node_19, to_csubst_node(x_node(1), node_19, [x_lt_5, x_ge_0])),
],
)
)
Expand Down Expand Up @@ -355,41 +355,41 @@ def x_lt(n: int) -> KApply:
'┃\n'
'┣━━┓ subst: .Subst\n'
'┃ ┃ constraint:\n'
'┃ ┃ _<Int_ ( X , 64 )\n'
'┃ ┃ _>=Int_ ( X , 0 )\n'
'┃ ┃ _>=Int_ ( X , 32 )\n'
'┃ ┃ _<Int_ ( X , 64 )\n'
'┃ │\n'
'┃ └─ 9 (leaf)\n'
'┃\n'
'┣━━┓ subst: .Subst\n'
'┃ ┃ constraint:\n'
'┃ ┃ _>=Int_ ( X , 0 )\n'
'┃ ┃ _<Int_ ( X , 32 )\n'
'┃ ┃ _>=Int_ ( X , 0 )\n'
'┃ ┃ _>=Int_ ( X , 16 )\n'
'┃ │\n'
'┃ └─ 10 (leaf)\n'
'┃\n'
'┣━━┓ subst: .Subst\n'
'┃ ┃ constraint:\n'
'┃ ┃ _>=Int_ ( X , 0 )\n'
'┃ ┃ _<Int_ ( X , 32 )\n'
'┃ ┃ _<Int_ ( X , 16 )\n'
'┃ ┃ _<Int_ ( X , 32 )\n'
'┃ ┃ _>=Int_ ( X , 0 )\n'
'┃ │\n'
'┃ └─ 11 (leaf)\n'
'┃\n'
'┣━━┓ subst: .Subst\n'
'┃ ┃ constraint:\n'
'┃ ┃ _<Int_ ( X , 0 )\n'
'┃ ┃ _>=Int_ ( X , -32 )\n'
'┃ ┃ _>=Int_ ( X , -16 )\n'
'┃ ┃ _>=Int_ ( X , -32 )\n'
'┃ │\n'
'┃ └─ 12 (leaf)\n'
'┃\n'
'┣━━┓ subst: .Subst\n'
'┃ ┃ constraint:\n'
'┃ ┃ _<Int_ ( X , 0 )\n'
'┃ ┃ _>=Int_ ( X , -32 )\n'
'┃ ┃ _<Int_ ( X , -16 )\n'
'┃ ┃ _>=Int_ ( X , -32 )\n'
'┃ │\n'
'┃ └─ 13 (leaf)\n'
'┃\n'
Expand Down
28 changes: 27 additions & 1 deletion pyk/src/tests/unit/test_cterm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pyk.kast.outer import KClaim
from pyk.prelude.k import GENERATED_TOP_CELL, K
from pyk.prelude.kbool import TRUE
from pyk.prelude.kint import INT, intToken
from pyk.prelude.kint import INT, geInt, intToken
from pyk.prelude.ml import mlAnd, mlEquals, mlEqualsTrue, mlTop

from .utils import a, b, c, f, g, ge_ml, h, k, lt_ml, x, y, z
Expand Down Expand Up @@ -75,6 +75,32 @@ def test_no_cterm_match(term: KInner, pattern: KInner) -> None:
assert subst is None


MATCH_WITH_CONSTRAINT_TEST_DATA: Final = (
(CTerm(k(x)), CTerm(k(x))),
(CTerm(k(x)), CTerm(k(y))),
(CTerm(k(x)), CTerm(k(y), (mlEqualsTrue(geInt(y, intToken(0))),))),
(
CTerm(k(x), (mlEqualsTrue(geInt(y, intToken(0))),)),
CTerm(
k(y),
(
mlEqualsTrue(geInt(y, intToken(0))),
mlEqualsTrue(geInt(y, intToken(5))),
),
),
),
)


@pytest.mark.parametrize('t1, t2', MATCH_WITH_CONSTRAINT_TEST_DATA, ids=count())
def test_cterm_match_with_constraint(t1: CTerm, t2: CTerm) -> None:
# When
c_subst1 = t1.match_with_constraint(t2)

# Then
assert c_subst1 is not None and c_subst1.apply(t1) == t2


BUILD_RULE_TEST_DATA: Final = (
(
T(k(KVariable('K_CELL')), mem(KVariable('MEM_CELL'))),
Expand Down
7 changes: 4 additions & 3 deletions pyk/src/tests/unit/test_kcfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import TYPE_CHECKING, Final

import pytest
from unit.utils import ge_ml, k, lt_ml

from pyk.cterm import CSubst, CTerm
from pyk.kast.inner import KApply, KRewrite, KToken, KVariable, Subst
Expand All @@ -17,6 +16,7 @@
from pyk.utils import not_none, single

from .mock_kprint import MockKPrint
from .utils import ge_ml, k, lt_ml

if TYPE_CHECKING:
from collections.abc import Iterable
Expand All @@ -27,8 +27,9 @@


def to_csubst_cterm(term_1: CTerm, term_2: CTerm, constraints: Iterable[KInner]) -> CSubst:
csubst = term_1.match_with_constraint(term_2)
assert csubst is not None
subst = term_1.config.match(term_2.config)
assert subst is not None
csubst = CSubst(subst, [])
for constraint in constraints:
csubst = csubst.add_constraint(constraint)
return csubst
Expand Down

0 comments on commit 8bf17a3

Please sign in to comment.