From e532adcc75fed4c03c8f2c1d8ad3954d7a47960f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Guillermo=20Caminero=20Fern=C3=A1ndez?= Date: Thu, 20 Jan 2022 09:59:12 +0100 Subject: [PATCH] Change episode over to Bool --- gym_classification/envs/env_4_rl_classification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gym_classification/envs/env_4_rl_classification.py b/gym_classification/envs/env_4_rl_classification.py index 4428669..4deb7fd 100644 --- a/gym_classification/envs/env_4_rl_classification.py +++ b/gym_classification/envs/env_4_rl_classification.py @@ -56,7 +56,7 @@ def init_dataset(self, X=None,y=None,batch_size=None,output_shape=None,randomize def reset(self): - self.episode_over = np.array([False]*len(self.current_indices)) + self.episode_over = False self.true_labels = np.take(self.y,self.current_indices,axis=0).ravel() if self.output_shape: return np.take(self.X,self.current_indices,axis=0).reshape(-1,*self.output_shape) @@ -75,7 +75,7 @@ def step(self, action): last_element = self.current_indices[-1] if(max(self.current_indices) + self.batch_size + 1) > len(self.X): - self.episode_over = np.array([True]*len(self.current_indices)) + self.episode_over = True ## greater if last_element == max(self.current_indices): self.current_indices += self.batch_size