diff --git a/physbo/search/discrete/policy.py b/physbo/search/discrete/policy.py index 27d72d7e..250f59b7 100644 --- a/physbo/search/discrete/policy.py +++ b/physbo/search/discrete/policy.py @@ -694,7 +694,13 @@ def load(self, file_history, file_training=None, file_predictor=None): self.predictor = pickle.load(f) N = self.history.total_num_search - self.actions = self._delete_actions(self.history.chosen_actions[:N]) + + visited = self.history.chosen_actions[:N] + local_index = np.searchsorted(self.actions, visited) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == visited + ] + self.actions = self._delete_actions(local_index) def export_predictor(self): """ diff --git a/physbo/search/discrete_multi/policy.py b/physbo/search/discrete_multi/policy.py index 40cc35a7..877201f4 100644 --- a/physbo/search/discrete_multi/policy.py +++ b/physbo/search/discrete_multi/policy.py @@ -490,7 +490,13 @@ def load(self, file_history, file_training_list=None, file_predictor_list=None): self.load_predictor_list(file_predictor_list) N = self.history.total_num_search - self.actions = self._delete_actions(self.history.chosen_actions[:N]) + + visited = self.history.chosen_actions[:N] + local_index = np.searchsorted(self.actions, visited) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == visited + ] + self.actions = self._delete_actions(local_index) def save_predictor_list(self, file_name): with open(file_name, "wb") as f: