diff --git a/tsgm/models/augmentations.py b/tsgm/models/augmentations.py index 6612a85..8daba28 100644 --- a/tsgm/models/augmentations.py +++ b/tsgm/models/augmentations.py @@ -169,6 +169,8 @@ def generate(self, X: TensorLike, y: Optional[TensorLike] = None, n_samples: int slices.append(s) slices.append(sequence[start_idx:]) np.random.shuffle(slices) + # concatenate the slices + sequence = np.concatenate(slices) synthetic_data.append(sequence) if has_labels: new_labels.append(y[i])