feat: add save_model_dir flag where final checkpoint saved (#291)
* add save_model_dir flag for final checkpoint

Signed-off-by: Anh-Uong <[email protected]>

* remove output_dir logic, add save method

Signed-off-by: Anh-Uong <[email protected]>

* update accelerate_launch, remove save tokenizer

Signed-off-by: Anh-Uong <[email protected]>

* fix: put back creation of .complete file

Signed-off-by: Anh-Uong <[email protected]>

* fix failing tests and add new ones

Signed-off-by: Anh-Uong <[email protected]>

* tests: add sft_trainer test to train and save

- small refactor of tests

Signed-off-by: Anh-Uong <[email protected]>

* add docs on saving checkpoints and fix help msg

Signed-off-by: Anh-Uong <[email protected]>

* update example and note best checkpoint

Signed-off-by: Anh-Uong <[email protected]>

* changes based on PR review

Signed-off-by: Anh-Uong <[email protected]>

* add logging to save, fix error out properly

Signed-off-by: Anh-Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
anhuong authored Aug 14, 2024
1 parent 0aae2aa commit 78909af
Showing 8 changed files with 343 additions and 158 deletions.
45 changes: 45 additions & 0 deletions README.md
@@ -6,6 +6,7 @@
- [Training](#training)
- [Single GPU](#single-gpu)
- [Multiple GPUs with FSDP](#multiple-gpus-with-fsdp)
- [Tips on Parameters to Set](#tips-on-parameters-to-set)
- [Tuning Techniques](#tuning-techniques)
- [LoRA Tuning Example](#lora-tuning-example)
- [Prompt Tuning](#prompt-tuning)
@@ -225,6 +226,50 @@ tuning/sft_trainer.py \

To summarize, you can pick either python for single-GPU jobs or use accelerate launch for multi-GPU jobs. The following tuning techniques can be applied:

### Tips on Parameters to Set

#### Saving checkpoints while training

By default, [`save_strategy`](tuning/config/configs.py) is set to `"epoch"` in the TrainingArguments, which means a checkpoint is saved at the end of every epoch. It can also be set to `"steps"` to save a checkpoint every `save_steps` steps, or to `"no"` to not save any checkpoints.

Checkpoints are saved to the given `output_dir`, which is a required field. If `save_strategy="no"`, the `output_dir` will only contain the training logs with loss details.

A useful flag for limiting the number of saved checkpoints is [`save_total_limit`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments.save_total_limit). Older checkpoints are deleted from `output_dir` to stay within the limit; for example, `save_total_limit=1` keeps only the most recent checkpoint. Note that while tuning, two checkpoints may exist in `output_dir` for a short time, since the new checkpoint is written before the older one is deleted. If the user provides a validation dataset and sets [`load_best_model_at_end`](https://huggingface.co/docs/transformers/en/main_classes/trainer#transformers.TrainingArguments.load_best_model_at_end), then the best checkpoint will be kept.
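
For illustration, a minimal command that saves a checkpoint every 50 steps and keeps only the most recent one could look like the sketch below (the model and data paths are placeholders, not values from this repository; the `save_*` flags are the standard HuggingFace TrainingArguments fields described above):

```sh
# Sketch only: placeholder model/data paths
python tuning/sft_trainer.py \
--model_name_or_path <base-model-path> \
--training_data_path <train-data.jsonl> \
--output_dir /tmp/output_dir \
--save_strategy "steps" \
--save_steps 50 \
--save_total_limit 1
```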

#### Saving model after training

`save_model_dir` can optionally be set to save the tuned model using `SFTTrainer.save_model()`. It can be used in tandem with `save_strategy="no"` to save only the designated final checkpoint and no intermediate checkpoints, which can help save space.
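
As an illustrative sketch (placeholder paths again), combining `save_model_dir` with `save_strategy="no"` so that only the final model is written could look like:

```sh
# Sketch only: no intermediate checkpoints are saved; the final model is
# written once to save_model_dir at the end of training.
python tuning/sft_trainer.py \
--model_name_or_path <base-model-path> \
--training_data_path <train-data.jsonl> \
--output_dir /tmp/output_dir \
--save_strategy "no" \
--save_model_dir /tmp/output_dir/save_model_dir
```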

`save_model_dir` can be set to a different directory than `output_dir`. If set to the same directory, the designated checkpoint, training logs, and any intermediate checkpoints will all be saved to the same directory as seen below.

<details>
<summary>Ways you can use `save_model_dir` and more tips:</summary>

For example, if `save_model_dir` is set to a sub-directory of `output_dir` and `save_total_limit=1` is used with LoRA tuning, the directory would look like:

```sh
$ ls /tmp/output_dir/
checkpoint-35 save_model_dir training_logs.jsonl

$ ls /tmp/output_dir/save_model_dir/
README.md adapter_model.safetensors special_tokens_map.json tokenizer.model training_args.bin
adapter_config.json added_tokens.json tokenizer.json tokenizer_config.json
```

Here is a fine-tuning example of how the directory would look if `output_dir` is set to the same value as `save_model_dir` and `save_total_limit=2`. Note the checkpoint directories as well as the `training_logs.jsonl`:

```sh
$ ls /tmp/same_dir

added_tokens.json model-00001-of-00006.safetensors model-00006-of-00006.safetensors tokenizer_config.json
checkpoint-16 model-00002-of-00006.safetensors model.safetensors.index.json training_args.bin
checkpoint-20 model-00003-of-00006.safetensors special_tokens_map.json training_logs.jsonl
config.json model-00004-of-00006.safetensors tokenizer.json
generation_config.json model-00005-of-00006.safetensors tokenizer.model
```

</details>

## Tuning Techniques:

### LoRA Tuning Example
226 changes: 95 additions & 131 deletions build/accelerate_launch.py
@@ -23,8 +23,6 @@
 import subprocess
 import sys
 import traceback
-import tempfile
-import shutil
 from pathlib import Path
 import json

@@ -37,12 +35,9 @@
 # Local
 from build.utils import (
     process_accelerate_launch_args,
-    serialize_args,
     get_highest_checkpoint,
-    copy_checkpoint,
 )
 from tuning.utils.config_utils import get_json_config
-from tuning.config.tracker_configs import FileLoggingTrackerConfig
 from tuning.utils.error_logging import (
     write_termination_log,
     USER_ERROR_EXIT_CODE,
@@ -111,142 +106,111 @@ def main():
     # Launch training
     #
     ##########
-    original_output_dir = job_config.get("output_dir")
-    with tempfile.TemporaryDirectory() as tempdir:
-        try:
-            # checkpoints outputted to tempdir, only final checkpoint copied to output dir
-            job_config["output_dir"] = tempdir
-            updated_args = serialize_args(job_config)
-            os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = updated_args
-            launch_command(args)
-        except subprocess.CalledProcessError as e:
-            # If the subprocess throws an exception, the base exception is hidden in the
-            # subprocess call and is difficult to access at this level. However, that is not
-            # an issue because sft_trainer.py would have already written the exception
-            # message to termination log.
-            logging.error(traceback.format_exc())
-            # The exit code that sft_trainer.py threw is captured in e.returncode
-
-            return_code = e.returncode
-            if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]:
-                return_code = INTERNAL_ERROR_EXIT_CODE
-                write_termination_log(f"Unhandled exception during training. {e}")
-            sys.exit(return_code)
-        except Exception as e:  # pylint: disable=broad-except
-            logging.error(traceback.format_exc())
-            write_termination_log(f"Unhandled exception during training. {e}")
-            sys.exit(INTERNAL_ERROR_EXIT_CODE)
+    output_dir = job_config.get("output_dir")
+    try:
+        # checkpoints outputted to tempdir, only final checkpoint copied to output dir
+        launch_command(args)
+    except subprocess.CalledProcessError as e:
+        # If the subprocess throws an exception, the base exception is hidden in the
+        # subprocess call and is difficult to access at this level. However, that is not
+        # an issue because sft_trainer.py would have already written the exception
+        # message to termination log.
+        logging.error(traceback.format_exc())
+        # The exit code that sft_trainer.py threw is captured in e.returncode
+
+        return_code = e.returncode
+        if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]:
+            return_code = INTERNAL_ERROR_EXIT_CODE
+            write_termination_log(f"Unhandled exception during training. {e}")
+        sys.exit(return_code)
+    except Exception as e:  # pylint: disable=broad-except
+        logging.error(traceback.format_exc())
+        write_termination_log(f"Unhandled exception during training. {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
 
-        try:
-            last_checkpoint_dir = get_highest_checkpoint(tempdir)
-            last_checkpoint_path = os.path.join(tempdir, last_checkpoint_dir)
-
-            use_flash_attn = job_config.get("use_flash_attn", True)
-            adapter_config_path = os.path.join(
-                last_checkpoint_path, "adapter_config.json"
-            )
-            tokenizer = AutoTokenizer.from_pretrained(last_checkpoint_path)
-
-            if os.path.exists(adapter_config_path):
-                base_model_path = get_base_model_from_adapter_config(
-                    adapter_config_path
-                )
-                base_model = AutoModelForCausalLM.from_pretrained(
-                    base_model_path,
-                    attn_implementation="flash_attention_2" if use_flash_attn else None,
-                    torch_dtype=bfloat16 if use_flash_attn else None,
-                )
-
-                # since the peft library (PEFTModelForCausalLM) does not handle cases
-                # where the model's layers are modified, in our case the embedding layer
-                # is modified, so we resize the backbone model's embedding layer with our own
-                # utility before passing it along to load the PEFT model.
-                tokenizer_data_utils.tokenizer_and_embedding_resize(
-                    {}, tokenizer=tokenizer, model=base_model
-                )
-                model = PeftModel.from_pretrained(
-                    base_model,
-                    last_checkpoint_path,
-                    attn_implementation="flash_attention_2" if use_flash_attn else None,
-                    torch_dtype=bfloat16 if use_flash_attn else None,
-                )
-            else:
-                model = AutoModelForCausalLM.from_pretrained(
-                    last_checkpoint_path,
-                    attn_implementation="flash_attention_2" if use_flash_attn else None,
-                    torch_dtype=bfloat16 if use_flash_attn else None,
-                )
-
-            model_arch = model.config.model_type
-            # check that it is a granite model with llama architecture with tied weights
-            # ie. lm_head is duplicate of embeddings
-
-            # a fine tuned model will have params_dict.get("model.embed_tokens.weight")
-            # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
-            # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
-            copy_checkpoint_bool = True
-            if model_arch == "llama" and hasattr(model, "lm_head"):
-                if (
-                    # lora tuned model has an addt model layer
-                    (
-                        hasattr(model.model, "model")
-                        and model.lm_head.weight.untyped_storage().data_ptr()
-                        == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
-                    )
-                    # prompt tuned model or fine tuned model
-                    or (
-                        hasattr(model.model, "embed_tokens")
-                        and model.lm_head.weight.untyped_storage().data_ptr()
-                        == model.model.embed_tokens.weight.untyped_storage().data_ptr()
-                    )
-                ):
-
-                    copy_checkpoint_bool = False
-                    logging.info("Removing lm_head from checkpoint")
-                    del model.lm_head.weight
-
-                    if hasattr(model, "lm_head.weight"):
-                        logging.warning("Failed to delete lm_head.weight from model")
-
-                    logging.info("Saving checkpoint to %s", original_output_dir)
-                    model.save_pretrained(original_output_dir)
-                    # save tokenizer with model
-                    tokenizer.save_pretrained(original_output_dir)
-
-            # copy last checkpoint into mounted output dir
-            if copy_checkpoint_bool:
-                logging.info(
-                    "Copying last checkpoint %s into output dir %s",
-                    last_checkpoint_dir,
-                    original_output_dir,
-                )
-                copy_checkpoint(last_checkpoint_path, original_output_dir)
-        except Exception as e:  # pylint: disable=broad-except
-            logging.error(traceback.format_exc())
-            write_termination_log(
-                f"Exception encountered writing output model to storage: {e}"
-            )
-            sys.exit(INTERNAL_ERROR_EXIT_CODE)
-
-        # copy over any loss logs
-        try:
-            train_logs_filepath = os.path.join(
-                tempdir,
-                FileLoggingTrackerConfig.training_logs_filename,
-            )
-            if os.path.exists(train_logs_filepath):
-                shutil.copy(train_logs_filepath, original_output_dir)
-
-            # The .complete file will signal to users that we are finished copying
-            # files over
-            if os.path.exists(original_output_dir):
-                Path(os.path.join(original_output_dir, ".complete")).touch()
-        except Exception as e:  # pylint: disable=broad-except
-            logging.error(traceback.format_exc())
-            write_termination_log(
-                f"Exception encountered in capturing training logs: {e}"
-            )
-            sys.exit(INTERNAL_ERROR_EXIT_CODE)
+    # remove lm_head from granite with llama arch models
+    try:
+        checkpoint_dir = job_config.get("save_model_dir")
+        if not checkpoint_dir:
+            checkpoint_dir = os.path.join(
+                output_dir, get_highest_checkpoint(output_dir)
+            )
+
+        use_flash_attn = job_config.get("use_flash_attn", True)
+        adapter_config_path = os.path.join(checkpoint_dir, "adapter_config.json")
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
+
+        if os.path.exists(adapter_config_path):
+            base_model_path = get_base_model_from_adapter_config(adapter_config_path)
+            base_model = AutoModelForCausalLM.from_pretrained(
+                base_model_path,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+            # since the peft library (PEFTModelForCausalLM) does not handle cases
+            # where the model's layers are modified, in our case the embedding layer
+            # is modified, so we resize the backbone model's embedding layer with our own
+            # utility before passing it along to load the PEFT model.
+            tokenizer_data_utils.tokenizer_and_embedding_resize(
+                {}, tokenizer=tokenizer, model=base_model
+            )
+            model = PeftModel.from_pretrained(
+                base_model,
+                checkpoint_dir,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+        else:
+            model = AutoModelForCausalLM.from_pretrained(
+                checkpoint_dir,
+                attn_implementation="flash_attention_2" if use_flash_attn else None,
+                torch_dtype=bfloat16 if use_flash_attn else None,
+            )
+
+        model_arch = model.config.model_type
+        # check that it is a granite model with llama architecture with tied weights
+        # ie. lm_head is duplicate of embeddings
+
+        # a fine tuned model will have params_dict.get("model.embed_tokens.weight")
+        # a prompt adapter has params_dict.get("base_model.model.embed_tokens.weight")
+        # a lora adapter has params_dict.get("base_model.model.model.embed_tokens.weight")
+        if model_arch == "llama" and hasattr(model, "lm_head"):
+            if (
+                # lora tuned model has an addt model layer
+                (
+                    hasattr(model.model, "model")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+                # prompt tuned model or fine tuned model
+                or (
+                    hasattr(model.model, "embed_tokens")
+                    and model.lm_head.weight.untyped_storage().data_ptr()
+                    == model.model.embed_tokens.weight.untyped_storage().data_ptr()
+                )
+            ):
+
+                logging.info("Removing lm_head from checkpoint")
+                del model.lm_head.weight
+
+                if hasattr(model, "lm_head.weight"):
+                    logging.warning("Failed to delete lm_head.weight from model")
+
+                logging.info("Saving checkpoint to %s", output_dir)
+                model.save_pretrained(checkpoint_dir)
+                # save tokenizer with model
+                tokenizer.save_pretrained(checkpoint_dir)
+
+    except Exception as e:  # pylint: disable=broad-except
+        logging.error(traceback.format_exc())
+        write_termination_log(f"Exception encountered removing lm_head from model: {e}")
+        sys.exit(INTERNAL_ERROR_EXIT_CODE)
+
+    # The .complete file will signal to users that we are finished copying
+    # files over
+    if os.path.exists(output_dir):
+        Path(os.path.join(output_dir, ".complete")).touch()
 
     return 0
