diff --git a/src/lightning_trainable/__init__.py b/src/lightning_trainable/__init__.py index 23e0dc3..f8203f1 100644 --- a/src/lightning_trainable/__init__.py +++ b/src/lightning_trainable/__init__.py @@ -4,5 +4,5 @@ from . import modules from . import trainable -from .trainable import Trainable, TrainableHParams +from .trainable import Trainable, TrainableHParams, SkipBatch from .hparams import HParams diff --git a/src/lightning_trainable/trainable/__init__.py b/src/lightning_trainable/trainable/__init__.py index 369ce7e..ac8b179 100644 --- a/src/lightning_trainable/trainable/__init__.py +++ b/src/lightning_trainable/trainable/__init__.py @@ -1 +1 @@ -from .trainable import Trainable, TrainableHParams +from .trainable import Trainable, TrainableHParams, SkipBatch