[patch] minor fixes for 0.0.2 (#109)
francoishernandez authored Sep 20, 2024
1 parent b5df05c commit 4c3adf1
Showing 5 changed files with 48 additions and 17 deletions.
44 changes: 28 additions & 16 deletions eole/bin/convert/convert_HF.py
@@ -587,6 +587,7 @@ def run(cls, args):
         eos_token = None
         optional_eos = []
         mapped_tokens = []
+        gpt2_pretok = False
 
         # ALL THESE IF SHOULD BE HANDLED IN MAPPINGS
         if arch == "PhiForCausalLM":
@@ -940,23 +941,35 @@ def get_weight(checkpoint, tensor_name):
             # Not sure if we could do much cleaner to retrieve optional eos tokens
             eos_token_id = config.get("eos_token_id", None)
             if isinstance(eos_token_id, list):
-                optional_eos = [
-                    data["added_tokens_decoder"][str(index)]["content"]
-                    for index in eos_token_id[1:]
-                ]
-                eos_token = optional_eos[0]
+                if "added_tokens_decoder" in data.keys():
+                    eos_tokens = [
+                        data["added_tokens_decoder"][str(index)]["content"]
+                        for index in eos_token_id[1:]
+                    ]
+                    optional_eos = eos_tokens[1:]
+                    eos_token = eos_tokens[0]
             elif isinstance(eos_token_id, int):
-                eos_token = data["added_tokens_decoder"][str(eos_token_id)][
-                    "content"
-                ]
+                if "eos_token" in data.keys():
+                    if isinstance(data["eos_token"], dict):
+                        # Llama2 style
+                        eos_token = data["eos_token"]["content"]
+                    elif isinstance(data["eos_token"], str):
+                        eos_token = data["eos_token"]
+                elif "added_tokens_decoder" in data.keys():
+                    eos_token = data["added_tokens_decoder"][str(eos_token_id)][
+                        "content"
+                    ]
             # Automatically convert added_tokens into mapped_tokens
-            mapped_tokens = [
-                (
-                    token["content"],
-                    re.sub(r"<\|([^|]*)\|>", "\uff5f\\1\uff60", token["content"]),
-                )
-                for token in data["added_tokens_decoder"].values()
-            ]
+            if "added_tokens_decoder" in data.keys():
+                mapped_tokens = [
+                    (
+                        token["content"],
+                        re.sub(
+                            r"<\|([^|]*)\|>", "\uff5f\\1\uff60", token["content"]
+                        ),
+                    )
+                    for token in data["added_tokens_decoder"].values()
+                ]
         else:
             add_bos_token = True
 
@@ -1009,7 +1022,6 @@ def get_weight(checkpoint, tensor_name):
         with open(tokenizer_json, encoding="utf-8") as f:
             data = json.load(f)
         # gpt2_pretok
-        gpt2_pretok = False
         pretokenizers = data.get("pre_tokenizer", {}).get("pretokenizers", [{}])
         for pretokenizer in pretokenizers:
             if pretokenizer.get("type", None) == "ByteLevel":
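To make the new control flow easier to follow, here is a minimal standalone sketch of the eos-token resolution order these hunks introduce (the function name and inputs are illustrative, not the converter's actual API): a list-valued eos_token_id is resolved through added_tokens_decoder, an int-valued one prefers an explicit eos_token entry (dict for Llama2-style configs, plain string otherwise) before falling back to added_tokens_decoder, and mapped_tokens are only built when added_tokens_decoder is present.

import re

def resolve_eos(config: dict, data: dict):
    """Sketch of the patched logic, given config.json and tokenizer_config.json dicts."""
    eos_token = None
    optional_eos = []
    eos_token_id = config.get("eos_token_id", None)
    added = data.get("added_tokens_decoder", {})
    if isinstance(eos_token_id, list):
        if added:
            eos_tokens = [added[str(i)]["content"] for i in eos_token_id[1:]]
            optional_eos = eos_tokens[1:]
            eos_token = eos_tokens[0]
    elif isinstance(eos_token_id, int):
        if "eos_token" in data:
            # Llama2-style configs store a dict, newer ones a plain string
            eos = data["eos_token"]
            eos_token = eos["content"] if isinstance(eos, dict) else eos
        elif added:
            eos_token = added[str(eos_token_id)]["content"]
    # empty when added_tokens_decoder is absent, matching the new guard above
    mapped_tokens = [
        (tok["content"], re.sub(r"<\|([^|]*)\|>", "\uff5f\\1\uff60", tok["content"]))
        for tok in added.values()
    ]
    return eos_token, optional_eos, mapped_tokens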
8 changes: 8 additions & 0 deletions eole/bin/model/lora_weights.py
@@ -52,6 +52,10 @@ def add_args(cls, parser):
     @classmethod
     def run(cls, args):
         init_logger()
+        config_path = os.path.join(args.base_model, "config.json")
+        with open(config_path) as f:
+            config = json.load(f)
+        inference_config = config.get("inference", None)
         base_checkpoint = load_checkpoint(args.base_model)
         lora_checkpoint = load_checkpoint(args.lora_weights)
         vocabs = dict_to_vocabs(lora_checkpoint["vocab"])
@@ -84,6 +88,8 @@ def run(cls, args):
             optim = None
             model_state_dict = model.state_dict()
             new_config = base_checkpoint["config"]
+            # use compute_dtype from lora finetuning
+            new_config.training.compute_dtype = config.training.compute_dtype
         elif args.action == "concat":
             model.half() # We keep FP16 for all
             optim = lora_checkpoint["optim"]
@@ -101,6 +107,8 @@
             json.dump(vocab_dict, f, indent=2, ensure_ascii=False)
         # save config
         config_dict = recursive_model_fields_set(new_config)
+        if inference_config is not None:
+            config_dict["inference"] = inference_config
         with open(os.path.join(args.output, "config.json"), "w", encoding="utf-8") as f:
             json.dump(config_dict, f, indent=2, ensure_ascii=False)
         shards = glob.glob(os.path.join(args.base_model, "model.*.safetensors"))
6 changes: 6 additions & 0 deletions eole/config/training.py
@@ -270,6 +270,11 @@ class TrainingConfig(
     score_threshold: float = Field(
         default=0.68, description="Threshold to filterout data"
     )
+    dummy_load: bool | None = Field(
+        default=False,
+        description="Ignore some warnings if we are only loading the configuration "
+        "prior to other operations, e.g. in `train_from` context.",
+    )
 
     @computed_field
     @cached_property
@@ -316,6 +321,7 @@ def _validate_running_config(self):
             torch.cuda.is_available()
             and not self.gpu_ranks
             and self.model_fields_set != set()
+            and not self.dummy_load
         ):
             logger.warn("You have a CUDA device, should run with -gpu_ranks")
         if self.world_size < len(self.gpu_ranks):
4 changes: 4 additions & 0 deletions eole/models/model_saver.py
@@ -53,6 +53,10 @@ def load_checkpoint(model_path):
         # drop inference to prevent validation issues
         if "inference" in config_dict.keys():
             config_dict.pop("inference")
+        if "training" in config_dict.keys():
+            config_dict["training"]["dummy_load"] = True
+        else:
+            config_dict["training"] = {"dummy_load": True}
         _config = TrainConfig(**config_dict)
         checkpoint["config"] = _config
     else:
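Taken together with the TrainingConfig change above, the effect is that a config re-validated while merely loading a checkpoint no longer emits the spurious CUDA warning. A minimal sketch under assumed, pared-down config contents; the dummy_load key follows the diff, everything else is illustrative:

# Illustrative only: model_saver.load_checkpoint forces dummy_load before
# re-validating, and the extended warning condition then stays silent.
config_dict = {"training": {"world_size": 0, "gpu_ranks": []}}  # hypothetical contents
config_dict.setdefault("training", {})["dummy_load"] = True

def should_warn_about_gpu(training: dict, cuda_available: bool) -> bool:
    # mirrors the extended condition in TrainingConfig._validate_running_config
    return (
        cuda_available
        and not training.get("gpu_ranks")
        and bool(training)  # stands in for model_fields_set != set()
        and not training.get("dummy_load", False)
    )

assert should_warn_about_gpu(config_dict["training"], cuda_available=True) is False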
3 changes: 2 additions & 1 deletion eole/modules/rmsnorm.py
@@ -24,7 +24,8 @@ def __init__(self, hidden_size: int, eps: float = 1e-6):
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
     def forward(self, hidden_states):
-        if AWQ_EXT and not self.training:
+        dtype = next(self.parameters()).dtype
+        if AWQ_EXT and not self.training and dtype == torch.float16:
             inp_type = hidden_states.dtype
             output = torch.empty_like(hidden_states).to(inp_type)
             if hidden_states.dim() == 2:  # patch for multi experts
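The added guard takes the AWQ fast path only when parameters are FP16, which is what the fused extension expects; with BF16 or FP32 weights the module falls through to the regular PyTorch computation. A minimal RMSNorm sketch with the same dispatch shape, assuming a generic fallback (the eole fallback branch is below the fold and not shown in this hunk):

import torch
import torch.nn as nn

class RMSNormSketch(nn.Module):
    """Sketch: same dtype dispatch as the patched eole RMSNorm, generic fallback."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        dtype = next(self.parameters()).dtype
        if not self.training and dtype == torch.float16:
            # in eole this branch would call the optional AWQ fused kernel (AWQ_EXT);
            # omitted here, so we simply fall through to the generic path
            pass
        # plain RMSNorm, computed in float32 for numerical stability
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps).to(hidden_states.dtype)
        return self.weight * hidden_states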
