Skip to content

Commit

Permalink
fix: setting log level in save() (#304)
Browse files Browse the repository at this point in the history
* fix: setting log level in save()

Signed-off-by: Anh Uong <[email protected]>

* add log level doc string to save

Signed-off-by: Anh Uong <[email protected]>

---------

Signed-off-by: Anh Uong <[email protected]>
  • Loading branch information
anhuong authored Aug 16, 2024
1 parent 2d1c17c commit a6d093e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
2 changes: 1 addition & 1 deletion tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no():
# validate that no checkpoints created
assert not any(x.startswith("checkpoint-") for x in os.listdir(tempdir))

sft_trainer.save(tempdir, trainer)
sft_trainer.save(tempdir, trainer, "debug")
assert any(x.endswith(".safetensors") for x in os.listdir(tempdir))
_test_run_inference(checkpoint_path=tempdir)

Expand Down
4 changes: 3 additions & 1 deletion tuning/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,15 @@ def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
Path to save the model to.
trainer: SFTTrainer
Instance of SFTTrainer used for training to save the model.
log_level: str
Optional threshold to set save save logger to, default warning.
"""
logger = logging.getLogger("sft_trainer_save")
# default value from TrainingArguments
if log_level == "passive":
log_level = "WARNING"

logger.setLevel(log_level)
logger.setLevel(log_level.upper())

if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
Expand Down

0 comments on commit a6d093e

Please sign in to comment.