Skip to content

Commit

Permalink
Pass trainer_kwargs to launcher.fit, allow updating dicts in parse_co…
Browse files Browse the repository at this point in the history
…nfig_dict
  • Loading branch information
fdraxler committed Sep 8, 2023
1 parent 96ef4a1 commit 91b59a8
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion src/lightning_trainable/launcher/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
from time import sleep

import yaml
from lightning import LightningModule
from lightning.pytorch.loggers import TensorBoardLogger

Expand All @@ -29,6 +30,8 @@ def main(args=None):
"You must specify at least a `model` config argument. "
"All other config arguments overwrite the values in the stored checkpoint."
)
parser.add_argument("--trainer-kwargs", type=yaml.safe_load, default={},
help="Pass kwargs to the trainer.")
parser.add_argument("--loose-load-state-dict", action="store_true", default=False,
help="When loading a state dict, set `strict`=False")
parser.add_argument("--gradient-regex", type=str, default=None,
Expand Down Expand Up @@ -139,7 +142,7 @@ def main(args=None):
print(f"Deactivated {deactivated_parameters} parameters, {remaining_parameters} parameters left as is.")

# Fit the model
return model.fit(logger_kwargs=logger_kwargs, fit_kwargs=fit_kwargs)
return model.fit(logger_kwargs=logger_kwargs, fit_kwargs=fit_kwargs, trainer_kwargs=args.trainer_kwargs)


if __name__ == '__main__':
Expand Down
5 changes: 4 additions & 1 deletion src/lightning_trainable/launcher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def dict_list_set(dl: dict | list, item, value):
else:
dl[int(item)] = value
else:
dl[item] = value
if item == "!":
dl.update(value)
else:
dl[item] = value


def send_telegram_message(message: str, token: str, chats: List[int]):
Expand Down

0 comments on commit 91b59a8

Please sign in to comment.