Skip to content

Commit

Permalink
get_transforms_cls after update_config_with_checkpoint
Browse files (browse the repository at this point in the history)
  • Loading branch information
francoishernandez committed Sep 23, 2024
1 parent d1a9af5 commit 6000008
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
3 changes: 2 additions & 1 deletion eole/models/model_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@ def load_checkpoint(model_path):
if not os.path.exists(model_path):
raise FileNotFoundError(f"{model_path} does not seem to exist.")
elif os.path.isdir(model_path):
+ os.environ["MODEL_PATH"] = model_path
logger.info("Loading checkpoint from %s" % model_path)
# checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
checkpoint = {}
config_path = os.path.join(model_path, "config.json")
if os.path.exists(config_path):
with open(config_path) as f:
- config_dict = json.load(f)
+ config_dict = json.loads(os.path.expandvars(f.read()))
# drop data to prevent validation issues
config_dict["data"] = {}
# drop inference to prevent validation issues
Expand Down
3 changes: 1 addition & 2 deletions eole/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ def _init_train(config):
- resume training but transforms have changed
- resume training but vocab file has been modified
"""
- transforms_cls = get_transforms_cls(config._all_transform)

if config.training.train_from:
checkpoint = load_checkpoint(config.training.train_from)

Expand Down Expand Up @@ -93,6 +91,7 @@ def _init_train(config):
checkpoint = None

config = update_config_with_checkpoint(config, checkpoint=checkpoint)
+ transforms_cls = get_transforms_cls(config._all_transform)
vocabs, transforms = prepare_transforms_vocabs(config, transforms_cls)
if config.training.train_from and not config.training.update_vocab:
logger.info("Keeping checkpoint vocabulary")
Expand Down
16 changes: 8 additions & 8 deletions recipes/llama2/llama-finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ overwrite: true
report_every: 10

# transforms config
- transforms: [sentencepiece, filtertoolong]
- transforms_configs:
- sentencepiece:
- src_subword_model: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf/tokenizer.model"
- tgt_subword_model: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf/tokenizer.model"
- filtertoolong:
- src_seq_length: 896
- tgt_seq_length: 896
+ # transforms: [sentencepiece, filtertoolong]
+ # transforms_configs:
+ # sentencepiece:
+ # src_subword_model: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf/tokenizer.model"
+ # tgt_subword_model: "${EOLE_MODEL_DIR}/llama2-7b-chat-hf/tokenizer.model"
+ # filtertoolong:
+ # src_seq_length: 896
+ # tgt_seq_length: 896

# datasets
data:
Expand Down

0 comments on commit 6000008

Please sign in to comment.