diff --git a/gym_classification/envs/env_4_rl_classification.py b/gym_classification/envs/env_4_rl_classification.py index 4428669..4deb7fd 100644 --- a/gym_classification/envs/env_4_rl_classification.py +++ b/gym_classification/envs/env_4_rl_classification.py @@ -56,7 +56,7 @@ def init_dataset(self, X=None,y=None,batch_size=None,output_shape=None,randomize def reset(self): - self.episode_over = np.array([False]*len(self.current_indices)) + self.episode_over = False self.true_labels = np.take(self.y,self.current_indices,axis=0).ravel() if self.output_shape: return np.take(self.X,self.current_indices,axis=0).reshape(-1,*self.output_shape) @@ -75,7 +75,7 @@ def step(self, action): last_element = self.current_indices[-1] if(max(self.current_indices) + self.batch_size + 1) > len(self.X): - self.episode_over = np.array([True]*len(self.current_indices)) + self.episode_over = True ## greater if last_element == max(self.current_indices): self.current_indices += self.batch_size