diff --git a/pantograph/search.py b/pantograph/search.py index 2b25685..a04ae32 100644 --- a/pantograph/search.py +++ b/pantograph/search.py @@ -1,9 +1,10 @@ +import random from abc import abstractmethod import time from dataclasses import dataclass -from typing import Optional, Self +from typing import Optional, Self, List import collections, unittest - +from math import log, sqrt from pantograph.server import Server, TacticFailure, ServerError from pantograph.expr import Expr, Tactic, GoalState @@ -15,12 +16,20 @@ class SearchState: parent: Optional[Self] parent_goal_id: Optional[int] priorities: list[float] + children: Optional[List[Self]] = None + tested_tactics: Optional[List[Tactic]] = None + total_value: Optional[float] = None tactic_feedback: Optional[str] = None def __post_init__(self): assert len(self.priorities) == len(self.goal_state.goals) self.solved = [False for _ in self.goal_state.goals] self.trials = [0 for _ in self.goal_state.goals] + self.tested_tactics = [] if self.tested_tactics is None else self.tested_tactics + self.children = [] if self.children is None else self.children + self.visit_count = 1 + self.exhausted = False + self.subtree_exhausted = False @property def next_goal_id(self) -> int: @@ -180,6 +189,148 @@ def search(self, ) +class MCTSAgent(Agent): + """ + An agent interface for proof search using monte carlo tree search + """ + + @abstractmethod + def next_tactic( + self, + state: GoalState, + goal_id: int, + tested: Optional[List[Tactic]] = None, + ) -> Optional[Tactic]: + """ + Implement this function to generate the next tactic for a goal given tactics already tested + """ + + @abstractmethod + def reset(self): + """ + Called after search + """ + + @abstractmethod + def estimate(self, state: SearchState) -> SearchState: + """ + Implement this function to estimate the value of a state + """ + + @abstractmethod + def select(self, state: SearchState) -> list[SearchState]: + """ + Implement this function to select the best node within the subtree of the state. + Returns the path to the selected node from the given state. + """ + + def backup(self, states: list[SearchState], value: float): + """ + Backup value of the state at the end of the states list. + """ + for state in states: + state.total_value += value + state.visit_count += 1 + state.subtree_exhausted = all(child.subtree_exhausted for child in state.children) and state.exhausted + + def search(self, + server: Server, + goal_state: GoalState, + max_steps: int = 100, + max_trials_per_goal: int = 5, + verbose: bool = False) -> SearchResult: + """ + Executes proof search on this state + """ + + assert server.is_automatic(), "Search must be run in automatic mode" + + n_goals_root = len(goal_state.goals) + time_start = time.time() + + initial_state = SearchState( + goal_state=goal_state, + parent=None, + parent_goal_id=None, + priorities=[0.0 for _ in goal_state.goals] + ) + initial_state = self.estimate(initial_state) + search_root = initial_state + + for i_step in range(max_steps): + search_trajectory = self.select(search_root) + search_state = search_trajectory[-1] + assert isinstance(search_state, SearchState) + + if search_state.is_solved: + return SearchResult( + n_goals_root=n_goals_root, + duration=time.time() - time_start, + success=True, + steps=i_step, + ) + + # Find the unsolved goal with the highest priority + goal_id = search_state.next_goal_id + + if search_state.trials[goal_id] > max_trials_per_goal: + # force halt the search + tactic = None + else: + # Generate tactic for this goal + tactic = self.next_tactic(search_state.goal_state, goal_id, search_state.tested_tactics) + + if verbose: + print(f"Next tactic: {tactic}") + if not tactic: + # resets the feedback + search_state.tactic_feedback = None + search_state.exhausted = True + search_state.subtree_exhausted = all(child.subtree_exhausted for child in search_state.children) + continue + assert tactic not in search_state.tested_tactics, "Tactic already seen!" + search_state.tested_tactics.append(tactic) + + try: + search_state.trials[goal_id] += 1 + state = search_state.goal_state + if verbose: + print(f"{state.state_id}.{goal_id}: {tactic} on {search_state.goal_state.goals[goal_id]}") + next_goal_state = server.goal_tactic(search_state.goal_state, goal_id, tactic) + # Generate priorities for the next goal state + priorities = [0.0 for _ in next_goal_state.goals] \ + if len(next_goal_state.goals) <= 1 else \ + self.guidance(next_goal_state) + parent = -1 + next_state = SearchState( + goal_state=next_goal_state, + parent=parent, + parent_goal_id=goal_id, + priorities=priorities + ) + next_state = self.estimate(next_state) + search_state.children.append(next_state) + self.backup(search_trajectory, next_state.total_value) + except TacticFailure as t: + if verbose: + print(f"Tactic failed: {t}") + search_state.tactic_feedback = str(t) + # try the next tactic. this one failed + except ServerError as e: + raise RuntimeError(f"While executing tactic: {tactic}") from e + + if verbose: + print("Search iteration limit exhausted") + + self.reset() + return SearchResult( + n_goals_root=n_goals_root, + duration=time.time() - time_start, + success=False, + steps=max_steps, + ) + + class DumbAgent(Agent): def __init__(self): @@ -221,6 +372,79 @@ def next_tactic( self.goal_tactic_id_map[key] = i + 1 return tactics[i] +class DumbMCTSAgent(MCTSAgent): + def __init__(self): + super().__init__() + + self.goal_tactic_id_map = collections.defaultdict(lambda : 0) + self.intros = [ + "intro", + ] + self.tactics = [ + "intro h", + "cases h", + "apply Or.inl", + "apply Or.inr", + ] + self.no_space_tactics = [ + "assumption", + ] + self.c = 0.6 + + def estimate(self, state: SearchState) -> SearchState: + state.total_value = random.random() + return state + + def select(self, state: SearchState) -> list[SearchState]: + """ + UCB scoring with taking the current state as one option, i.e. one child + """ + state_trajectory = [state] + current_state = state + current_state_ucb = (state.total_value / state.visit_count) + self.c * sqrt((log(state.visit_count) / state.visit_count)) + while current_state.children: + avg_val = [child.total_value / child.visit_count for child in current_state.children] + visit_portions = [sqrt(log(current_state.visit_count) / child.visit_count) for child in current_state.children] + ucbs = [avg + self.c * visit for avg, visit in zip(avg_val, visit_portions, strict=True)] + child_idcs = [idx for idx in range(len(current_state.children)) if not current_state.children[idx].subtree_exhausted] + if not child_idcs: + return state_trajectory + child_idx = child_idcs[0] + for i in child_idcs: + if ucbs[i] > ucbs[child_idx]: + child_idx = i + if ucbs[child_idx] < current_state_ucb and not current_state.exhausted: + return state_trajectory + current_state_ucb = ucbs[child_idx] + current_state = current_state.children[child_idx] + state_trajectory.append(current_state) + return state_trajectory + + def next_tactic( + self, + state: GoalState, + goal_id: int, + tested: Optional[List[Tactic]] = None + ) -> Optional[Tactic]: + key = (state.state_id, goal_id) + i = self.goal_tactic_id_map[key] + target = state.goals[goal_id].target + if target.startswith('∀'): + tactics = self.intros + elif ' ' in target: + tactics = self.tactics + else: + tactics = self.no_space_tactics + + if i >= len(tactics): + return None + self.goal_tactic_id_map[key] = i + 1 + while tactics[i] in tested: + i += 1 + if i >= len(tactics): + return None + return tactics[i] + class TestSearch(unittest.TestCase): @@ -246,6 +470,31 @@ def test_solve_big(self): verbose=False) self.assertTrue(flag) +class TestMCTSSearch(unittest.TestCase): + + def test_solve(self): + + server = Server() + agent = DumbMCTSAgent() + goal_state = server.goal_start("∀ (p q: Prop), p -> p") + flag = agent.search( + server=server, + goal_state=goal_state, + verbose=False) + #flag = agent.search(server=server, target="∀ (p q: Prop), Or p q -> Or q p", verbose=True) + self.assertTrue(flag) + def test_solve_big(self): + + server = Server() + agent = DumbMCTSAgent() + goal_state = server.goal_start("∀ (p q: Prop), Or p q -> Or q p") + flag = agent.search( + server=server, + goal_state=goal_state, + max_steps=200, + verbose=False) + self.assertTrue(flag) + if __name__ == '__main__': unittest.main()