basic_model_loader.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
import os
from loguru import logger
import json
from peft import PeftModel
from .utils import log
from huggingface_hub import hf_hub_download

try:
    from auto_gptq import AutoGPTQForCausalLM
except ImportError:
    # Fall back to the regular transformers loader when auto_gptq is not installed.
    from transformers import AutoModelForCausalLM as AutoGPTQForCausalLM
    log(logger.warning, "Failed to import auto_gptq")


def load_tokenizer(dir_or_model):
    """
    Load the tokenizer for a specific pre-trained model.

    Args:
        dir_or_model: Either a directory containing the pre-trained model (or LoRA adapter) files or the name of a pretrained model.

    Returns:
        A tokenizer that converts text into tokens for the specific model input.
    """
    log(logger.debug, f"Loading tokenizer for {dir_or_model}")

    # A LoRA adapter directory is identified by the presence of adapter_config.json.
    is_lora_dir = os.path.isfile(os.path.join(dir_or_model, "adapter_config.json"))

    if is_lora_dir:
        loaded_json = json.load(open(os.path.join(dir_or_model, "adapter_config.json"), "r"))
        model_name = loaded_json["base_model_name_or_path"]
    else:
        model_name = dir_or_model

    if os.path.isfile(os.path.join(dir_or_model, "config.json")):
        loaded_json = json.load(open(os.path.join(dir_or_model, "config.json"), "r"))
        if "_name_or_path" in loaded_json:
            model_name = loaded_json["_name_or_path"]

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    except Exception:
        # The tokenizer name may be stored in adapter_config.json instead,
        # so download the adapter config and read the base model name from it.
        try:
            adapter_config = hf_hub_download(model_name, 'adapter_config.json')
            name = json.load(open(adapter_config, "r"))["base_model_name_or_path"]
        except Exception:
            # Last resort: fall back to a known base tokenizer.
            if "Llama-2" in model_name:
                name = 'meta-llama/Llama-2-7b-hf'
            else:
                name = "microsoft/phi-2"
        tokenizer = AutoTokenizer.from_pretrained(name)

    if tokenizer.pad_token is None:
        log(logger.debug, "Setting pad token to eos token")
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    return tokenizer
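
# Example usage of load_tokenizer (a minimal sketch, not part of the original module;
# "microsoft/phi-2" is only an illustrative model id, reused from the fallback above):
#
#   tokenizer = load_tokenizer("microsoft/phi-2")
#   batch = tokenizer(["Hello world"], return_tensors="pt", padding=True)
#   print(batch.input_ids.shape)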


def load_model(dir_or_model, classification=False, token_classification=False, return_tokenizer=False, dtype=torch.bfloat16, load_dtype=True,
               peft_config=None, device_map="auto", revision='main', attn_implementation='flash_attention_2',
               trust_remote_code=True):
    """
    Load a model based on several parameters, including the type of task it is meant to perform.

    Args:
        dir_or_model: Either a directory containing the pre-trained model (or LoRA adapter) files or the name of a pretrained model.
        classification (bool): If True, loads the model for sequence classification.
        token_classification (bool): If True, loads the model for token classification.
        return_tokenizer (bool): If True, returns the tokenizer along with the model.
        dtype: The data type that PyTorch should use internally to store the model's parameters and do the computation.
        load_dtype (bool): If False, sets dtype to torch.float32 regardless of the passed dtype value.
        peft_config: Configuration details for PEFT models (accepted but not used by this loader).
        device_map: Device placement passed to from_pretrained (default "auto").
        revision: The model revision (branch, tag, or commit) to load.
        attn_implementation: The attention implementation to request (default 'flash_attention_2').
        trust_remote_code (bool): Whether to trust custom code shipped with the model repository.

    Returns:
        The model for the required task, along with its tokenizer if return_tokenizer is True.
    """
log(logger.debug, f"Loading model for {dir_or_model} with {classification}, {dtype}, {load_dtype}")
is_lora_dir = os.path.isfile(os.path.join(dir_or_model, "adapter_config.json"))
if not load_dtype:
dtype = torch.float32
if is_lora_dir:
loaded_json = json.load(open(os.path.join(dir_or_model, "adapter_config.json"), "r"))
model_name = loaded_json["base_model_name_or_path"]
else:
model_name = dir_or_model
original_model_name = model_name
if classification:
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=trust_remote_code, torch_dtype=dtype, use_auth_token=True, device_map=device_map, revision=revision) # to investigate: calling torch_dtype here fails.
elif token_classification:
model = AutoModelForTokenClassification.from_pretrained(model_name, trust_remote_code=trust_remote_code, torch_dtype=dtype, use_auth_token=True, device_map=device_map, revision=revision)
else:
if model_name.endswith("GPTQ") or model_name.endswith("GGML"):
model = AutoGPTQForCausalLM.from_quantized(model_name,
use_safetensors=True,
trust_remote_code=trust_remote_code,
# use_triton=True, # breaks currently, unfortunately generation time of the GPTQ model is quite slow
quantize_config=None, device_map=device_map)
else:
try:
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code, torch_dtype=dtype, use_auth_token=True, device_map=device_map, revision=revision,
attn_implementation=attn_implementation)
except Exception:
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code, torch_dtype=dtype, use_auth_token=True, device_map=device_map, revision=revision)
if is_lora_dir:
try:
model = PeftModel.from_pretrained(model, dir_or_model, attn_implementation=attn_implementation)
except Exception:
model = PeftModel.from_pretrained(model, dir_or_model)
try:
tokenizer = load_tokenizer(original_model_name)
model.config.pad_token_id = tokenizer.pad_token_id
except Exception:
pass
if return_tokenizer:
return model, load_tokenizer(original_model_name)
return model
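
# Example usage of load_model (a minimal sketch, not part of the original module; the
# model id is illustrative, and gated models additionally require Hugging Face authentication):
#
#   model, tokenizer = load_model("microsoft/phi-2", return_tokenizer=True)
#   inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
#   output_ids = model.generate(**inputs, max_new_tokens=16)
#   print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
#
# Passing a LoRA adapter directory instead resolves the base model from its
# adapter_config.json and wraps it with PeftModel, e.g.:
#
#   lora_model = load_model("path/to/adapter_dir")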