From 2a033370125f2f165c9e4a9b1f85ef6bf5a5dd22 Mon Sep 17 00:00:00 2001 From: Yuichi Motoyama Date: Fri, 30 Aug 2024 15:23:25 +0900 Subject: [PATCH] fixed policy.load for MPI mode --- physbo/search/discrete/policy.py | 8 +++++++- physbo/search/discrete_multi/policy.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/physbo/search/discrete/policy.py b/physbo/search/discrete/policy.py index 27d72d7e..250f59b7 100644 --- a/physbo/search/discrete/policy.py +++ b/physbo/search/discrete/policy.py @@ -694,7 +694,13 @@ def load(self, file_history, file_training=None, file_predictor=None): self.predictor = pickle.load(f) N = self.history.total_num_search - self.actions = self._delete_actions(self.history.chosen_actions[:N]) + + visited = self.history.chosen_actions[:N] + local_index = np.searchsorted(self.actions, visited) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == visited + ] + self.actions = self._delete_actions(local_index) def export_predictor(self): """ diff --git a/physbo/search/discrete_multi/policy.py b/physbo/search/discrete_multi/policy.py index 40cc35a7..877201f4 100644 --- a/physbo/search/discrete_multi/policy.py +++ b/physbo/search/discrete_multi/policy.py @@ -490,7 +490,13 @@ def load(self, file_history, file_training_list=None, file_predictor_list=None): self.load_predictor_list(file_predictor_list) N = self.history.total_num_search - self.actions = self._delete_actions(self.history.chosen_actions[:N]) + + visited = self.history.chosen_actions[:N] + local_index = np.searchsorted(self.actions, visited) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == visited + ] + self.actions = self._delete_actions(local_index) def save_predictor_list(self, file_name): with open(file_name, "wb") as f: