Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load GPT-J from HF #39

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 14 additions & 27 deletions magma/language_model.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,32 @@
import torch
from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
from transformers import AutoModelForCausalLM, GPTJForCausalLM, GPTJConfig
from .utils import print_main
from pathlib import Path
from transformers.modeling_utils import no_init_weights
from magma.config import MultimodalConfig

LANGUAGE_MODELS = [
"gptj",
]


def gptj_config():
config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
config.attention_layers = ["global"] * 28
config.attention_types = [["global"], 28]
config.num_layers = 28
config.num_heads = 16
config.hidden_size = 256 * config.num_heads
config.vocab_size = 50400
config.rotary = True
config.rotary_dim = 64
config.jax = True
config.gradient_checkpointing = True
return config


def get_gptj(
def get_gptj(config: MultimodalConfig,
gradient_checkpointing: bool = True,
from_pretrained=False,
from_pretrained="EleutherAI/gpt-j-6B",
) -> torch.nn.Module:
"""
Loads GPTJ language model from HF
"""
print_main("Loading GPTJ language model...")
config = gptj_config()
config.gradient_checkpointing = gradient_checkpointing
gptj_config = GPTJConfig()
gptj_config.gradient_checkpointing = gradient_checkpointing
if gradient_checkpointing:
config.use_cache = False
config.model_device = "cpu"
if from_pretrained:
raise NotImplemented("GPTJ pretrained not implemented")
gptj_config.use_cache = False

if config.deepspeed_config_params['fp16']['enabled'] is True:
model = GPTJForCausalLM.from_pretrained(
from_pretrained, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True, config=gptj_config
)
else:
with no_init_weights():
model = GPTNeoForCausalLM(config=config)
model = AutoModelForCausalLM.from_pretrained(from_pretrained, config=gptj_config)

return model
19 changes: 18 additions & 1 deletion magma/magma.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from copy import deepcopy
from typing import Literal, Optional, List
from torchtyping import TensorType
from torch.nn.modules.container import ModuleList, Sequential
from torch.nn.parameter import Parameter
from transformers.file_utils import ModelOutput
from magma.config import MultimodalConfig

Expand Down Expand Up @@ -40,7 +42,7 @@ def __init__(self, config, device=None):
"cuda" if torch.cuda.is_available() else "cpu"
)
self.config = config
self.lm = get_gptj() #.to(self.device)
self.lm = get_gptj(config) #.to(self.device)
self.seq_len = self.lm.config.max_position_embeddings

self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)
Expand Down Expand Up @@ -89,6 +91,21 @@ def __init__(self, config, device=None):
**attn_config,
)

#check weights contiguous
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason why we need to check this?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The weights for GPT-J's attention layers are not contiguous, which will raise Tensors must be contiguous in deepspeed.
I find a similar issue here

for name, param in self.named_parameters():
if param.is_contiguous() is False:
path, param = name.rsplit(".",1)
path = path.split('.')
ref = self
while path:
element, path = path[0], path[1:]
if type(ref) in {Sequential, ModuleList}:
ref = ref[int(element)]
else:
ref = getattr(ref, element)
setattr(ref, param, Parameter(getattr(ref, param).contiguous()))
# print(name, getattr(ref, param).is_contiguous())

# freeze parameters
if config.freeze_lm:
for name, param in self.lm.named_parameters(): # freeze lm weights
Expand Down
2 changes: 1 addition & 1 deletion magma/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_tokenizer(name="gpt2", sequence_length=2048):
"""
if name == "gpt2":
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"
tokenizer.model_max_length = sequence_length
# setup lm settings
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
torchtyping
typeguard
git+https://github.com/finetuneanon/transformers.git#egg=transformers
transformers
gdown
tqdm
timm
Expand Down