diff --git a/src/lib/Templates.ml b/src/lib/Templates.ml index 20ebb2ef..91a4c7ef 100644 --- a/src/lib/Templates.ml +++ b/src/lib/Templates.ml @@ -585,6 +585,7 @@ and instantiate_state_factor (param_map: var_map): unexpanded_parsed_state_predi | Unexpanded_Parsed_state_predicate_factor_NOT f -> Unexpanded_Parsed_state_predicate_factor_NOT (instantiate_state_factor param_map f) | Unexpanded_Parsed_simple_predicate spred -> Unexpanded_Parsed_simple_predicate (instantiate_simple_predicate param_map spred) | Unexpanded_Parsed_state_predicate pred -> Unexpanded_Parsed_state_predicate (instantiate_state_predicate param_map pred) + (* TODO: merge next 4 branches *) | Unexpanded_Parsed_forall_state_predicate (index_data, pred) -> let binded_name = index_data.forall_index_name in let prev_binded_value_opt = Hashtbl.find_opt param_map binded_name in @@ -600,15 +601,30 @@ and instantiate_state_factor (param_map: var_map): unexpanded_parsed_state_predi | Unexpanded_Parsed_forall_simple_predicate (index_data, spred) -> let binded_name = index_data.forall_index_name in let prev_binded_value_opt = Hashtbl.find_opt param_map binded_name in - (* Here we remove the binde to allow the inner forall to rebind it. *) - (* This means that if the user had something like `forall i in [0, 1]: (forall i in [0, 1]: (p(i)) && q(i))`, the inner one will be - substituted first, and the outer one will substitute only `q(i)` *) Hashtbl.remove param_map binded_name; let r = Unexpanded_Parsed_forall_simple_predicate (index_data, instantiate_simple_predicate param_map spred) in begin match prev_binded_value_opt with | Some v -> Hashtbl.add param_map binded_name v; r | None -> r end + | Unexpanded_Parsed_exists_state_predicate (index_data, pred) -> + let binded_name = index_data.forall_index_name in + let prev_binded_value_opt = Hashtbl.find_opt param_map binded_name in + Hashtbl.remove param_map binded_name; + let r = Unexpanded_Parsed_exists_state_predicate (index_data, instantiate_state_predicate param_map pred) in + begin match prev_binded_value_opt with + | Some v -> Hashtbl.add param_map binded_name v; r + | None -> r + end + | Unexpanded_Parsed_exists_simple_predicate (index_data, spred) -> + let binded_name = index_data.forall_index_name in + let prev_binded_value_opt = Hashtbl.find_opt param_map binded_name in + Hashtbl.remove param_map binded_name; + let r = Unexpanded_Parsed_exists_simple_predicate (index_data, instantiate_simple_predicate param_map spred) in + begin match prev_binded_value_opt with + | Some v -> Hashtbl.add param_map binded_name v; r + | None -> r + end let expand_loc_pred (g_decls: variable_declarations): unexpanded_parsed_loc_predicate -> parsed_loc_predicate = function | Unexpanded_Parsed_loc_predicate_EQ (aut_name, loc) -> Parsed_loc_predicate_EQ (expand_name_or_access g_decls aut_name, loc) @@ -633,7 +649,7 @@ and expand_state_term (g_decls: variable_declarations): unexpanded_parsed_state_ and expand_state_factor (g_decls: variable_declarations) (factor: unexpanded_parsed_state_predicate_factor): parsed_state_predicate_factor = match factor with - (* TODO: Refactor the next two branches *) + (* TODO: merge next 4 branches *) | Unexpanded_Parsed_forall_simple_predicate (index_data, spred) -> begin let indices = indices_from_forall_index_data g_decls index_data in (* Here we reverse the list to preserve the order of the index after the fold *) @@ -686,6 +702,53 @@ and expand_state_factor (g_decls: variable_declarations) (factor: unexpanded_par in List.fold_left fold_fun (Parsed_state_predicate pred_i) is end + | Unexpanded_Parsed_exists_simple_predicate (index_data, spred) -> begin + let indices = indices_from_forall_index_data g_decls index_data in + match List.rev indices with + | [] -> failwith "error or vacuity?" + | i :: is -> + (* Compute the disjunction of spred instantiated with each idx in (i::is) *) + let instantiate_spred_with idx = + let param_map_idx = gen_aux_var_tbl idx index_data in + instantiate_simple_predicate param_map_idx spred |> + expand_simple_predicate g_decls + in + let spred_i = instantiate_spred_with i in + let state_pred_of_simple_pred p = + Parsed_state_predicate_term (Parsed_state_predicate_factor (Parsed_simple_predicate p)) + in + let fold_fun acc idx = + let spred_idx = instantiate_spred_with idx in + let spred_idx_state = state_pred_of_simple_pred spred_idx in + match acc with + | Parsed_state_predicate p -> + Parsed_state_predicate (Parsed_state_predicate_OR (spred_idx_state, p)) + | Parsed_simple_predicate p -> + let p_state= state_pred_of_simple_pred p in + Parsed_state_predicate (Parsed_state_predicate_OR (spred_idx_state, p_state)) + | _ -> failwith "[expand_state_factor]: unreachable" + in + List.fold_left fold_fun (Parsed_simple_predicate spred_i) is + end + | Unexpanded_Parsed_exists_state_predicate (index_data, pred) -> begin + let indices = indices_from_forall_index_data g_decls index_data in + match List.rev indices with + | [] -> failwith "error or vacuity?" + | i :: is -> + let instantiate_pred_with idx = + let param_map_idx = gen_aux_var_tbl idx index_data in + instantiate_state_predicate param_map_idx pred |> + expand_state_predicate g_decls + in + let pred_i = instantiate_pred_with i in + let fold_fun acc idx = + let pred_idx = instantiate_pred_with idx in + match acc with + | Parsed_state_predicate p -> Parsed_state_predicate (Parsed_state_predicate_OR (pred_idx, p)) + | _ -> failwith "[expand_state_factor]: unreachable" + in + List.fold_left fold_fun (Parsed_state_predicate pred_i) is + end | Unexpanded_Parsed_state_predicate_factor_NOT f -> Parsed_state_predicate_factor_NOT (expand_state_factor g_decls f) | Unexpanded_Parsed_simple_predicate spred -> Parsed_simple_predicate (expand_simple_predicate g_decls spred) | Unexpanded_Parsed_state_predicate pred -> Parsed_state_predicate (expand_state_predicate g_decls pred)