Skip to content

Commit

Permalink
Merge pull request #58 from issp-center-dev/fix_57
Browse files Browse the repository at this point in the history
fix `policy.write` method to write the last action
  • Loading branch information
yomichi authored Aug 20, 2024
2 parents 606c76b + 0a180ed commit e9ab4e3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
14 changes: 9 additions & 5 deletions physbo/search/discrete/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,15 @@ def write(
time_run_simulator=time_run_simulator,
)
self.training.add(X=X, t=t, Z=Z)
local_index = np.searchsorted(self.actions, action)
local_index = local_index[
np.take(self.actions, local_index, mode="clip") == action
]
self.actions = self._delete_actions(local_index)

# remove the selected actions from the list of candidates if exists
if len(self.actions) > 0:
local_index = np.searchsorted(self.actions, action)
local_index = local_index[
np.take(self.actions, local_index, mode="clip") == action
]
self.actions = self._delete_actions(local_index)

if self.new_data is None:
self.new_data = variable(X=X, t=t, Z=Z)
else:
Expand Down
13 changes: 8 additions & 5 deletions physbo/search/discrete_multi/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def write(
else:
self.new_data_list[i].add(X=X, t=t[:, i], Z=Z)
self.training_list[i].add(X=X, t=t[:, i], Z=Z)
local_index = np.searchsorted(self.actions, action)
local_index = local_index[
np.take(self.actions, local_index, mode="clip") == action
]
self.actions = self._delete_actions(local_index)

# remove action from candidates if exists
if len(self.actions) > 0:
local_index = np.searchsorted(self.actions, action)
local_index = local_index[
np.take(self.actions, local_index, mode="clip") == action
]
self.actions = self._delete_actions(local_index)

def _model(self, i):
training = self.training_list[i]
Expand Down

0 comments on commit e9ab4e3

Please sign in to comment.