From a6d093eb96e62311ffa02cdeb97d78c90fc86e74 Mon Sep 17 00:00:00 2001 From: Anh Uong Date: Fri, 16 Aug 2024 09:47:26 -0600 Subject: [PATCH] fix: setting log level in save() (#304) * fix: setting log level in save() Signed-off-by: Anh Uong * add log level doc string to save Signed-off-by: Anh Uong --------- Signed-off-by: Anh Uong --- tests/test_sft_trainer.py | 2 +- tuning/sft_trainer.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 0264d3b3..fc7ab144 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -468,7 +468,7 @@ def test_run_causallm_ft_save_with_save_model_dir_save_strategy_no(): # validate that no checkpoints created assert not any(x.startswith("checkpoint-") for x in os.listdir(tempdir)) - sft_trainer.save(tempdir, trainer) + sft_trainer.save(tempdir, trainer, "debug") assert any(x.endswith(".safetensors") for x in os.listdir(tempdir)) _test_run_inference(checkpoint_path=tempdir) diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py index 45fea7ca..b9ea0027 100644 --- a/tuning/sft_trainer.py +++ b/tuning/sft_trainer.py @@ -373,13 +373,15 @@ def save(path: str, trainer: SFTTrainer, log_level="WARNING"): Path to save the model to. trainer: SFTTrainer Instance of SFTTrainer used for training to save the model. + log_level: str + Optional threshold to set save logger to, default WARNING. """ logger = logging.getLogger("sft_trainer_save") # default value from TrainingArguments if log_level == "passive": log_level = "WARNING" - logger.setLevel(log_level) + logger.setLevel(log_level.upper()) if not os.path.exists(path): os.makedirs(path, exist_ok=True)