diff --git a/data.py b/data.py index e6ba0c4..9ada051 100644 --- a/data.py +++ b/data.py @@ -100,6 +100,8 @@ def get_class_one_hot(self, class_str): # Now one-hot it. label_hot = to_categorical(label_encoded, len(self.classes)) + assert len(label_hot) == len(self.classes) + return label_hot def split_train_test(self):