diff --git a/onmt/models/model.py b/onmt/models/model.py index ca10b0cc04..774b79a97c 100644 --- a/onmt/models/model.py +++ b/onmt/models/model.py @@ -45,6 +45,58 @@ def update_dropout(self, dropout, attention_dropout): def count_parameters(self, log=print): raise NotImplementedError + def _load_param(self, name, module, param_name, param, buf_list, ckpt_t, offset): + + if name.split(".")[-1] in [ + "linear_keys", + "linear_values", + "linear_query", + "w_1", + "w_3", + ]: + col_slice_start = param.data.size(0) * offset + col_slice_end = param.data.size(0) * (offset + 1) + else: + col_slice_start = 0 + col_slice_end = param.data.size(0) + if param.data.dim() == 2: + if name.split(".")[-1] in ["final_linear", "w_2"]: + row_slice_start = param.data.size(1) * offset + row_slice_end = param.data.size(1) * (offset + 1) + else: + row_slice_start = 0 + row_slice_end = param.data.size(1) + assert ( + param.data.size() + == ckpt_t[ + col_slice_start:col_slice_end, + row_slice_start:row_slice_end, + ].size() + ), "An error in model's partition and checkpoint's slice was detected" + if param_name in buf_list: + module.register_buffer( + param_name, + ckpt_t[ + col_slice_start:col_slice_end, + row_slice_start:row_slice_end, + ], + ) + else: + param.data = ckpt_t[ + col_slice_start:col_slice_end, + row_slice_start:row_slice_end, + ] + else: + assert ( + param.data.size() == ckpt_t[col_slice_start:col_slice_end].size() + ), "An error in model's partition and checkpoint's slice was detected" + if param_name in buf_list: + module.register_buffer( + param_name, ckpt_t[col_slice_start:col_slice_end] + ) + else: + param.data = ckpt_t[col_slice_start:col_slice_end] + def load_state_dict( self, checkpoint, @@ -71,61 +123,26 @@ def load_state_dict( for name, module in self.named_modules(): for buf_name, buf in module.named_buffers(): buf_list.append(buf_name) - if len(buf_name.split(".")) == 1: # only last key - if precision != torch.int8: - module.to(precision) - module.to(device) - for param_name, param in module.named_parameters(): + named_buf_and_param = list(module.named_buffers()) + list( + module.named_parameters() + ) + for param_name, param in named_buf_and_param: if len(param_name.split(".")) == 1: # only last key if name + "." + param_name in checkpoint["model"].keys(): ckpt_t = checkpoint["model"][name + "." + param_name] - - if name.split(".")[-1] in [ - "linear_keys", - "linear_values", - "linear_query", - "w_1", - "w_3", - ]: - col_slice_start = param.data.size(0) * offset - col_slice_end = param.data.size(0) * (offset + 1) - else: - col_slice_start = 0 - col_slice_end = param.data.size(0) - if param.data.dim() == 2: - if name.split(".")[-1] in ["final_linear", "w_2"]: - row_slice_start = param.data.size(1) * offset - row_slice_end = param.data.size(1) * (offset + 1) - else: - row_slice_start = 0 - row_slice_end = param.data.size(1) - assert ( - param.data.size() - == ckpt_t[ - col_slice_start:col_slice_end, - row_slice_start:row_slice_end, - ].size() - ), "An error in model's partition and checkpoint's slice was detected" - param.data = ckpt_t[ - col_slice_start:col_slice_end, - row_slice_start:row_slice_end, - ] - else: - assert ( - param.data.size() - == ckpt_t[col_slice_start:col_slice_end].size() - ), "An error in model's partition and checkpoint's slice was detected" - param.data = ckpt_t[col_slice_start:col_slice_end] - + self._load_param( + name, module, param_name, param, buf_list, ckpt_t, offset + ) del checkpoint["model"][name + "." + param_name] elif ( "generator" in checkpoint.keys() - and name == "generator" + and "generator" in name and checkpoint["generator"] is not None and param_name in checkpoint["generator"].keys() ): - param.data = checkpoint["generator"][param_name] - del checkpoint["generator"][param_name] + keyname = name + "." + param_name if "linear" in name else param_name + param.data = checkpoint["generator"][keyname] + del checkpoint["generator"][keyname] elif strict and "lora" not in param_name: raise ValueError( "Missing key in checkpoint: %s" % name + "." + param_name @@ -133,6 +150,7 @@ def load_state_dict( if precision != torch.int8: module.to(precision) module.to(device) + for key in checkpoint[ "model" ].keys(): # if some keys are left in checkpoint after deletion @@ -182,62 +200,25 @@ def load_safe_state_dict( f.append(safetensors.safe_open(shard, framework="pt", device="cpu")) for key in f[i].keys(): keys_shard[key] = i + if device == torch.device("cpu"): + offset = 0 buf_list = [] for name, module in self.named_modules(): for buf_name, buf in module.named_buffers(): buf_list.append(buf_name) - if len(buf_name.split(".")) == 1: # only last key - if precision == torch.int8: - torch.quantization.quantize_dynamic(module, inplace=True) - else: - module.to(precision) - module.to(device) - for param_name, param in module.named_parameters(): + named_buf_and_param = list(module.named_buffers()) + list( + module.named_parameters() + ) + for param_name, param in named_buf_and_param: if len(param_name.split(".")) == 1: # only last key if name + "." + param_name in keys_shard.keys(): ckpt_t = f[keys_shard[name + "." + param_name]].get_tensor( name + "." + param_name ) - if name.split(".")[-1] in [ - "linear_keys", - "linear_values", - "linear_query", - "w_1", - "w_3", - ]: - col_slice_start = param.data.size(0) * offset - col_slice_end = param.data.size(0) * (offset + 1) - else: - col_slice_start = 0 - col_slice_end = param.data.size(0) - if param.data.dim() == 2: - if name.split(".")[-1] in ["final_linear", "w_2"]: - row_slice_start = param.data.size(1) * offset - row_slice_end = param.data.size(1) * (offset + 1) - else: - row_slice_start = 0 - row_slice_end = param.data.size(1) - assert ( - param.data.size() - == ckpt_t[ - col_slice_start:col_slice_end, - row_slice_start:row_slice_end, - ].size() - ), "An error in model's partition and checkpoint's slice was detected" - - param.data = ckpt_t[ - col_slice_start:col_slice_end, - row_slice_start:row_slice_end, - ] - else: - assert ( - param.data.size() - == ckpt_t[col_slice_start:col_slice_end].size() - ), "An error in model's partition and checkpoint's slice was detected" - - param.data = ckpt_t[col_slice_start:col_slice_end] - + self._load_param( + name, module, param_name, param, buf_list, ckpt_t, offset + ) keyfound[name + "." + param_name] = True elif strict and "lora" not in param_name: raise ValueError(