diff --git a/data.py b/data.py index 56ecf95..81abada 100644 --- a/data.py +++ b/data.py @@ -100,9 +100,9 @@ def get_class_one_hot(self, class_str): # Now one-hot it. label_hot = to_categorical(label_encoded, len(self.classes)) - assert len(label_hot) == len(self.classes) + assert len(label_hot[0]) == len(self.classes) - return label_hot + return label_hot[0] def split_train_test(self): """Split the data into train and test groups."""