update docs
Lars Kuehmichel committed Mar 23, 2023
1 parent 71d8903 commit 99c7d62
Showing 2 changed files with 68 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/lightning_trainable/modules/dense_module.py
@@ -24,7 +24,7 @@ def __init__(self, hparams: DenseModuleHParams | dict):
def forward(self, batch: torch.Tensor) -> torch.Tensor:
return self.network(batch)

-def configure_network(self):
+def configure_network(self) -> nn.Module:
widths = [self.hparams.inputs, *self.hparams.layer_widths, self.hparams.outputs]

layers = []
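For orientation, constructing this module might look like the sketch below; the import path is inferred from the file location, and the exact DenseModuleHParams schema is an assumption, with the inputs, layer_widths, and outputs fields taken from the widths computation above.

    from lightning_trainable.modules import DenseModule, DenseModuleHParams

    # Field names taken from configure_network above; values are placeholders.
    hparams = DenseModuleHParams(
        inputs=8,
        layer_widths=[64, 64],
        outputs=1,
    )
    module = DenseModule(hparams)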
84 changes: 67 additions & 17 deletions src/lightning_trainable/trainable/trainable.py
@@ -43,7 +43,19 @@ def __init_subclass__(cls, **kwargs):
cls.hparams_type = cls.__annotations__.get("hparams", TrainableHParams)

def compute_metrics(self, batch, batch_idx) -> dict:
""" Compute any relevant metrics, including the loss, on the given batch """
"""
Compute any relevant metrics, including the loss, on the given batch.
You should return a dict in the style of {metric_name: metric_value} from this method,
where metric_value is scalar. The loss as defined by your hparams must also be a key
of this dictionary.
You *must* override this method and you *must* return the loss as defined by your hparams
within the dictionary, if you want to use :func:`trainable.Trainable.fit`
@param batch: The batch to compute metrics on. Usually a Tensor or a tuple of Tensors.
@param batch_idx: Index of this batch.
@return: Dictionary containing metrics to log and the loss to perform backpropagation on.
"""
raise NotImplementedError

def training_step(self, batch, batch_idx):
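To make the contract concrete, a minimal override might look like the following sketch; the network attribute and the "mse" loss name are hypothetical.

    import torch.nn.functional as F

    class MyRegressor(Trainable):
        def compute_metrics(self, batch, batch_idx) -> dict:
            x, y = batch
            y_hat = self.network(x)  # hypothetical submodule
            return {
                "mse": F.mse_loss(y_hat, y),  # must match hparams.loss ("mse" is assumed here)
                "mae": F.l1_loss(y_hat, y),   # extra metric, logged alongside the loss
            }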
@@ -75,11 +87,17 @@ def test_step(self, batch, batch_idx):

def configure_lr_schedulers(self, optimizer):
"""
-Configure LR Schedulers for Lightning
+Configure the LR Scheduler as defined in HParams.
+By default, we use a single LR Scheduler attached to a single optimizer.
+You can use a ChainedScheduler if you need multiple LR Schedulers throughout training,
+or override this method if you need different schedulers for different parameter groups.
+@param optimizer: The optimizer to attach the scheduler to.
+@return: The LR Scheduler object.
"""
match self.hparams.lr_scheduler:
case str() as name:
-match name:
+match name.lower():
case "onecyclelr":
kwargs = dict(
max_lr=optimizer.defaults["lr"],
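As a sketch of driving this from HParams, selecting the scheduler by name might look as follows; aside from lr_scheduler and loss, the TrainableHParams fields shown are assumptions.

    hparams = TrainableHParams(
        loss="mse",
        lr_scheduler="OneCycleLR",  # matched case-insensitively via name.lower() above
        max_epochs=10,              # assumed field
        batch_size=32,              # assumed field
    )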
@@ -126,7 +144,11 @@ def configure_lr_schedulers(self, optimizer):

def configure_optimizers(self):
"""
-Configure optimizers for Lightning
+Configure the Optimizer and LR Scheduler objects as defined in HParams.
+By default, we use a single optimizer and an optional LR Scheduler.
+If you need multiple optimizers, override this method.
+@return: A dictionary containing the optimizer and lr_scheduler.
"""
kwargs = dict()
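If you do need multiple optimizers, an override might look like this sketch; the generator and discriminator attributes are hypothetical, and recent Lightning versions require manual optimization when returning several optimizers.

    import torch

    class MyGAN(Trainable):
        def configure_optimizers(self):
            # Lightning accepts multiple optimizers returned as a tuple
            # (used together with self.automatic_optimization = False).
            opt_g = torch.optim.Adam(self.generator.parameters(), lr=1e-4)
            opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=4e-4)
            return opt_g, opt_d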

@@ -155,9 +177,13 @@ def configure_optimizers(self):
lr_scheduler=lr_scheduler,
)

-def configure_callbacks(self):
+def configure_callbacks(self) -> list:
"""
-Configure and return train callbacks for Lightning
+Configure train callbacks used by the Lightning Trainer when fitting this module.
+We provide some useful defaults here, but you may override this method if you want
+different callbacks. Callbacks defined here override those provided directly to the
+Lightning Trainer object.
+@return: A list of train callbacks.
"""
if self.val_data is None:
monitor = f"training/{self.hparams.loss}"
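For example, a subclass could extend the defaults rather than replace them; in the sketch below, the EarlyStopping import path assumes Lightning 2.x, and the validation/... monitor key is inferred from the training/... pattern above.

    from lightning.pytorch.callbacks import EarlyStopping

    class MyModel(Trainable):
        def configure_callbacks(self) -> list:
            callbacks = super().configure_callbacks()
            callbacks.append(EarlyStopping(monitor=f"validation/{self.hparams.loss}"))
            return callbacks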
@@ -174,9 +200,11 @@ def configure_callbacks(self):
EpochProgressBar(),
]

-def train_dataloader(self):
+def train_dataloader(self) -> DataLoader | None:
"""
-Configure and return the train dataloader
+Configures the train DataLoader for Lightning, using the dataset you passed as train_data.
+@return: The DataLoader object, or None if no train_data was given.
"""
if self.train_data is None:
return None
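A sketch of supplying the datasets when constructing the module; the constructor signature is an assumption based on the train_data and val_data attributes used here.

    import torch
    from torch.utils.data import TensorDataset

    x, y = torch.randn(1000, 8), torch.randn(1000, 1)
    data = TensorDataset(x, y)

    # Constructor signature assumed from the attributes referenced above.
    model = MyRegressor(hparams, train_data=data, val_data=data)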
@@ -188,9 +216,11 @@
num_workers=self.hparams.num_workers,
)

-def val_dataloader(self):
+def val_dataloader(self) -> DataLoader | None:
"""
-Configure and return the validation dataloader
+Configures the validation DataLoader for Lightning, using the dataset you passed as val_data.
+@return: The DataLoader object, or None if no val_data was given.
"""
if self.val_data is None:
return None
@@ -202,9 +232,11 @@
num_workers=self.hparams.num_workers,
)

-def test_dataloader(self):
+def test_dataloader(self) -> DataLoader | None:
"""
-Configure and return the test dataloader
+Configures the test DataLoader for Lightning, using the dataset you passed as test_data.
+@return: The DataLoader object, or None if no test_data was given.
"""
if self.test_data is None:
return None
@@ -216,19 +248,29 @@
num_workers=self.hparams.num_workers,
)

-def configure_logger(self, **kwargs):
+def configure_logger(self, **kwargs) -> lightning.loggers.Logger:
"""
-Configure and return the Logger to be used by the Lightning.Trainer
+Instantiate the Logger used by the Trainer when fitting this module.
+By default, we use a TensorBoardLogger, but you can use any other logger of your choice.
+@param kwargs: Keyword arguments passed to the Logger.
+@return: The Logger object.
"""
kwargs.setdefault("save_dir", os.getcwd())
return TensorBoardLogger(
default_hp_metric=False,
**kwargs
)
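Swapping in a different logger could look like this sketch; CSVLogger ships with Lightning 2.x, though the override itself is illustrative.

    import os
    from lightning.pytorch.loggers import CSVLogger

    class MyModel(Trainable):
        def configure_logger(self, **kwargs):
            kwargs.setdefault("save_dir", os.getcwd())
            return CSVLogger(**kwargs)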

-def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None):
+def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> lightning.Trainer:
"""
-Configure and return the Trainer used to train this module
+Instantiate the Lightning Trainer used to train this module.
+@param logger_kwargs: Keyword arguments passed to the Logger.
+See also :func:`~trainable.Trainable.configure_logger`.
+@param trainer_kwargs: Keyword arguments passed to the Lightning Trainer.
+See also :class:`lightning.Trainer`.
+@return: The Lightning Trainer object.
"""
if logger_kwargs is None:
logger_kwargs = dict()
@@ -254,7 +296,15 @@ def configure_trainer(self, logger_kwargs: dict = None, trainer_kwargs: dict = None):

@torch.enable_grad()
def fit(self, logger_kwargs: dict = None, trainer_kwargs: dict = None) -> dict:
""" Fit the module to data and return validation metrics """
"""
Instantiate a Lightning Trainer and use it to fit the module to data.
@param logger_kwargs: Keyword-Arguments to the Logger.
See also :func:`~trainable.Trainable.configure_logger`.
@param trainer_kwargs: Keyword-Arguments to the Lightning Trainer.
See also :func:`~trainable.Trainable.configure_trainer`.
@return: Validation Metrics as defined in :func:`~trainable.Trainable.compute_metrics`.
"""
if logger_kwargs is None:
logger_kwargs = dict()
if trainer_kwargs is None:
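Putting it together, a typical training run might look like the sketch below; the keyword values are placeholders.

    metrics = model.fit(
        logger_kwargs=dict(name="my_experiment"),  # forwarded to the Logger
        trainer_kwargs=dict(accelerator="cpu"),    # forwarded to the Lightning Trainer
    )
    print(metrics)  # validation metrics from compute_metrics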
