From 88ba2a4bf593988e36d79eb8dc670bc2a403102b Mon Sep 17 00:00:00 2001
From: Amnon Bleich
Date: Mon, 21 Aug 2023 11:39:12 +0200
Subject: [PATCH] bug fix - remove attn.bias keys from GPT state dict in
 'from_pretrained', otherwise the assertion fails

---
 mingpt/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mingpt/model.py b/mingpt/model.py
index 83ee22dc..a7692e36 100644
--- a/mingpt/model.py
+++ b/mingpt/model.py
@@ -187,6 +187,7 @@ def from_pretrained(cls, model_type):
         config.block_size = 1024 # openai's model block_size
         model = GPT(config)
         sd = model.state_dict()
+        keys_sd = [k for k in sd if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param
 
         # init a huggingface/transformers model
         model_hf = GPT2LMHeadModel.from_pretrained(model_type)
@@ -197,7 +198,7 @@ def from_pretrained(cls, model_type):
         transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
         # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear.
         # this means that we have to transpose these weights when we import them
-        assert len(keys) == len(sd)
+        assert len(keys) == len(keys_sd)
         for k in keys:
             if any(k.endswith(w) for w in transposed):
                 # special treatment for the Conv1D weights we need to transpose
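
The added filter matters because nn.Module.state_dict() includes tensors registered
with register_buffer() (such as the causal-mask buffer named 'bias' in the attention
block) alongside learnable parameters, so the raw key count of minGPT's state dict no
longer matches the filtered list of HuggingFace checkpoint keys. Below is a minimal,
self-contained sketch of that behaviour; the class and module names are illustrative
only, not the real mingpt/model.py code.

import torch
import torch.nn as nn

class Attn(nn.Module):
    def __init__(self, n_embd=8, block_size=4):
        super().__init__()
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        # causal mask kept as a buffer named 'bias', mirroring the GPT-2 layout;
        # it shows up in state_dict() even though it is not a parameter
        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)))

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = Attn()

class Mini(nn.Module):
    def __init__(self):
        super().__init__()
        self.h = nn.ModuleList([Block() for _ in range(2)])

sd = Mini().state_dict()
print(list(sd))
# ['h.0.attn.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias',
#  'h.1.attn.bias', 'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias']

# same filter as the patch: drop the mask buffers, keep the real parameters
keys_sd = [k for k in sd if not k.endswith('.attn.bias')]
print(keys_sd)
# ['h.0.attn.c_attn.weight', 'h.0.attn.c_attn.bias',
#  'h.1.attn.c_attn.weight', 'h.1.attn.c_attn.bias']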