From 0a180ed0e962f72027b15ffef6414f1f951841ad Mon Sep 17 00:00:00 2001 From: Yuichi Motoyama Date: Tue, 20 Aug 2024 14:37:22 +0900 Subject: [PATCH] fix policy.write --- physbo/search/discrete/policy.py | 14 +++++++++----- physbo/search/discrete_multi/policy.py | 13 ++++++++----- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/physbo/search/discrete/policy.py b/physbo/search/discrete/policy.py index b09485bb..27d72d7e 100644 --- a/physbo/search/discrete/policy.py +++ b/physbo/search/discrete/policy.py @@ -138,11 +138,15 @@ def write( time_run_simulator=time_run_simulator, ) self.training.add(X=X, t=t, Z=Z) - local_index = np.searchsorted(self.actions, action) - local_index = local_index[ - np.take(self.actions, local_index, mode="clip") == action - ] - self.actions = self._delete_actions(local_index) + + # remove the selected actions from the list of candidates if exists + if len(self.actions) > 0: + local_index = np.searchsorted(self.actions, action) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == action + ] + self.actions = self._delete_actions(local_index) + if self.new_data is None: self.new_data = variable(X=X, t=t, Z=Z) else: diff --git a/physbo/search/discrete_multi/policy.py b/physbo/search/discrete_multi/policy.py index 78f14582..40cc35a7 100644 --- a/physbo/search/discrete_multi/policy.py +++ b/physbo/search/discrete_multi/policy.py @@ -106,11 +106,14 @@ def write( else: self.new_data_list[i].add(X=X, t=t[:, i], Z=Z) self.training_list[i].add(X=X, t=t[:, i], Z=Z) - local_index = np.searchsorted(self.actions, action) - local_index = local_index[ - np.take(self.actions, local_index, mode="clip") == action - ] - self.actions = self._delete_actions(local_index) + + # remove action from candidates if exists + if len(self.actions) > 0: + local_index = np.searchsorted(self.actions, action) + local_index = local_index[ + np.take(self.actions, local_index, mode="clip") == action + ] + self.actions = self._delete_actions(local_index) def _model(self, i): training = self.training_list[i]