refactor state_dict loading
vince62s committed Nov 8, 2023
1 parent 3aeba3e commit 589d4f1
Showing 1 changed file with 73 additions and 92 deletions.
onmt/models/model.py: 165 changes (73 additions, 92 deletions)
@@ -45,6 +45,58 @@ def update_dropout(self, dropout, attention_dropout):
def count_parameters(self, log=print):
raise NotImplementedError

def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset):

if name.split(".")[-1] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
col_slice_start = param.data.size(0) * offset
col_slice_end = param.data.size(0) * (offset + 1)
else:
col_slice_start = 0
col_slice_end = param.data.size(0)
if param.data.dim() == 2:
if name.split(".")[-1] in ["final_linear", "w_2"]:
row_slice_start = param.data.size(1) * offset
row_slice_end = param.data.size(1) * (offset + 1)
else:
row_slice_start = 0
row_slice_end = param.data.size(1)
assert (
param.data.size()
== ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
].size()
), "An error in model's partition and checkpoint's slice was detected"
if param_name in buf_list:
module.register_buffer(
param_name,
ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
],
)
else:
param.data = ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
]
else:
assert (
param.data.size() == ckpt_t[col_slice_start:col_slice_end].size()
), "An error in model's partition and checkpoint's slice was detected"
if param_name in buf_list:
module.register_buffer(
param_name, ckpt_t[col_slice_start:col_slice_end]
)
else:
param.data = ckpt_t[col_slice_start:col_slice_end]
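
Aside, not part of the diff: the new _load_param centralizes the tensor-parallel slicing rule that the old code inlined in both load_state_dict and load_safe_state_dict. A rough standalone sketch of that rule for 2D weights follows; slice_for_rank, its arguments, and the shapes are illustrative, only the dim-0 / dim-1 split mirrors the method above.

import torch

def slice_for_rank(full_t, local_shape, module_name, offset):
    # Column-parallel layers (linear_keys, linear_values, linear_query, w_1, w_3)
    # take a contiguous block along dim 0; row-parallel layers (final_linear, w_2)
    # take a contiguous block along dim 1; anything else is loaded whole.
    d0, d1 = local_shape
    if module_name in {"linear_keys", "linear_values", "linear_query", "w_1", "w_3"}:
        s0 = slice(d0 * offset, d0 * (offset + 1))
    else:
        s0 = slice(0, d0)
    if module_name in {"final_linear", "w_2"}:
        s1 = slice(d1 * offset, d1 * (offset + 1))
    else:
        s1 = slice(0, d1)
    return full_t[s0, s1]

# Example: rank 1 of a 2-way split takes the second half of a w_1 weight along dim 0.
full_w1 = torch.randn(4096, 1024)
local = slice_for_rank(full_w1, (2048, 1024), "w_1", offset=1)
assert local.shape == (2048, 1024)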

def load_state_dict(
self,
checkpoint,
@@ -71,68 +123,34 @@ def load_state_dict(
for name, module in self.named_modules():
for buf_name, buf in module.named_buffers():
buf_list.append(buf_name)
if len(buf_name.split(".")) == 1: # only last key
if precision != torch.int8:
module.to(precision)
module.to(device)
for param_name, param in module.named_parameters():
named_buf_and_param = list(module.named_buffers()) + list(
module.named_parameters()
)
for param_name, param in named_buf_and_param:
if len(param_name.split(".")) == 1: # only last key
if name + "." + param_name in checkpoint["model"].keys():
ckpt_t = checkpoint["model"][name + "." + param_name]

if name.split(".")[-1] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
col_slice_start = param.data.size(0) * offset
col_slice_end = param.data.size(0) * (offset + 1)
else:
col_slice_start = 0
col_slice_end = param.data.size(0)
if param.data.dim() == 2:
if name.split(".")[-1] in ["final_linear", "w_2"]:
row_slice_start = param.data.size(1) * offset
row_slice_end = param.data.size(1) * (offset + 1)
else:
row_slice_start = 0
row_slice_end = param.data.size(1)
assert (
param.data.size()
== ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
].size()
), "An error in model's partition and checkpoint's slice was detected"
param.data = ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
]
else:
assert (
param.data.size()
== ckpt_t[col_slice_start:col_slice_end].size()
), "An error in model's partition and checkpoint's slice was detected"
param.data = ckpt_t[col_slice_start:col_slice_end]

self._load_param(
name, module, param_name, param, buf_list, ckpt_t, offset
)
del checkpoint["model"][name + "." + param_name]
elif (
"generator" in checkpoint.keys()
and name == "generator"
and "generator" in name
and checkpoint["generator"] is not None
and param_name in checkpoint["generator"].keys()
):
param.data = checkpoint["generator"][param_name]
del checkpoint["generator"][param_name]
keyname = name + "." + param_name if "linear" in name else param_name
param.data = checkpoint["generator"][keyname]
del checkpoint["generator"][keyname]
elif strict and "lora" not in param_name:
raise ValueError(
"Missing key in checkpoint: %s" % name + "." + param_name
)
if precision != torch.int8:
module.to(precision)
module.to(device)

for key in checkpoint[
"model"
].keys(): # if some keys are left in checkpoint after deletion
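
Aside, not part of the diff: the len(param_name.split(".")) == 1 filter above works because named_buffers() and named_parameters() on a module recurse into its children, prefixing their names with dots; keeping only dot-free names restricts the loop to tensors owned directly by the current module, so each tensor is handled exactly once while iterating named_modules(). A small sketch of that behaviour (the toy model is illustrative):

import torch.nn as nn

m = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
for name, module in m.named_modules():
    direct = [
        n
        for n, _ in list(module.named_buffers()) + list(module.named_parameters())
        if len(n.split(".")) == 1  # keep only tensors owned directly by `module`
    ]
    print(name or "<root>", direct)
# expected:
# <root> []   (the container owns nothing directly; its children do)
# 0 ['weight', 'bias']
# 1 ['running_mean', 'running_var', 'num_batches_tracked', 'weight', 'bias']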
@@ -182,62 +200,25 @@ def load_safe_state_dict(
f.append(safetensors.safe_open(shard, framework="pt", device="cpu"))
for key in f[i].keys():
keys_shard[key] = i
if device == torch.device("cpu"):
offset = 0
buf_list = []
for name, module in self.named_modules():
for buf_name, buf in module.named_buffers():
buf_list.append(buf_name)
if len(buf_name.split(".")) == 1: # only last key
if precision == torch.int8:
torch.quantization.quantize_dynamic(module, inplace=True)
else:
module.to(precision)
module.to(device)
for param_name, param in module.named_parameters():
named_buf_and_param = list(module.named_buffers()) + list(
module.named_parameters()
)
for param_name, param in named_buf_and_param:
if len(param_name.split(".")) == 1: # only last key
if name + "." + param_name in keys_shard.keys():

ckpt_t = f[keys_shard[name + "." + param_name]].get_tensor(
name + "." + param_name
)
if name.split(".")[-1] in [
"linear_keys",
"linear_values",
"linear_query",
"w_1",
"w_3",
]:
col_slice_start = param.data.size(0) * offset
col_slice_end = param.data.size(0) * (offset + 1)
else:
col_slice_start = 0
col_slice_end = param.data.size(0)
if param.data.dim() == 2:
if name.split(".")[-1] in ["final_linear", "w_2"]:
row_slice_start = param.data.size(1) * offset
row_slice_end = param.data.size(1) * (offset + 1)
else:
row_slice_start = 0
row_slice_end = param.data.size(1)
assert (
param.data.size()
== ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
].size()
), "An error in model's partition and checkpoint's slice was detected"

param.data = ckpt_t[
col_slice_start:col_slice_end,
row_slice_start:row_slice_end,
]
else:
assert (
param.data.size()
== ckpt_t[col_slice_start:col_slice_end].size()
), "An error in model's partition and checkpoint's slice was detected"

param.data = ckpt_t[col_slice_start:col_slice_end]

self._load_param(
name, module, param_name, param, buf_list, ckpt_t, offset
)
keyfound[name + "." + param_name] = True
elif strict and "lora" not in param_name:
raise ValueError(

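Aside, not part of the diff: load_safe_state_dict builds a key-to-shard index once (keys_shard[key] = i above) and then pulls each tensor from the shard that owns it only when a matching parameter is found. A minimal sketch of that pattern; the directory layout and glob are assumptions, while safe_open, keys() and get_tensor() are the safetensors calls used in the diff:

import glob
import safetensors

shard_paths = sorted(glob.glob("model_dir/*.safetensors"))  # hypothetical layout
handles, keys_shard = [], {}
for i, path in enumerate(shard_paths):
    handles.append(safetensors.safe_open(path, framework="pt", device="cpu"))
    for key in handles[i].keys():
        keys_shard[key] = i  # remember which shard holds this key

# Later, any tensor can be fetched from the right shard without loading the rest:
# ckpt_t = handles[keys_shard[some_key]].get_tensor(some_key)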