forked from pytorch/torchtune
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_model_builders.py
236 lines (206 loc) · 8.46 KB
/
_model_builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import List, Optional
from torchtune.data._prompt_templates import _get_prompt_template, _TemplateType
from torchtune.models.llama3._component_builders import llama3, lora_llama3
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.modules import TransformerDecoder
from torchtune.modules.peft import LORA_ATTN_MODULES
from torchtune.modules.tokenizers import parse_hf_tokenizer_json
"""
Model builders build specific instantiations using component builders. For example,
the llama3_8b model builder uses the llama3 component builder to create the
Llama3 8B model.
"""
def llama3_8b() -> TransformerDecoder:
    """
    Build a Llama3 8B model using its default architecture hyperparameters.

    All hyperparameters are gathered into a single mapping and forwarded as
    keyword arguments to the ``llama3`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3 8B model
    """
    config = dict(
        vocab_size=128_256,
        num_layers=32,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
    )
    return llama3(**config)
def llama3_70b() -> TransformerDecoder:
    """
    Builder for creating a Llama3 model initialized w/ the default 70B parameter values.

    The values are passed directly to the ``llama3`` component builder.

    Returns:
        TransformerDecoder: Instantiation of Llama3 70B model
    """
    return llama3(
        vocab_size=128_256,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        max_seq_len=8192,
        intermediate_dim=28672,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
    )
def llama3_tokenizer(
    path: str,
    special_tokens_path: Optional[str] = None,
    max_seq_len: Optional[int] = None,
    prompt_template: Optional[_TemplateType] = None,
) -> Llama3Tokenizer:
    """
    Construct the Llama3 tokenizer.

    Args:
        path (str): path to the tokenizer
        special_tokens_path (Optional[str]): Path to ``tokenizer.json`` from Hugging Face
            model files that contains all registered special tokens, or a local json file
            structured similarly. Default is None to use the canonical Llama3 special tokens.
        max_seq_len (Optional[int]): maximum sequence length for tokenizing a single list of messages,
            after which the input will be truncated. Default is None.
        prompt_template (Optional[_TemplateType]): optional specified prompt template.
            If a string, it is assumed to be the dotpath of a :class:`~torchtune.data.PromptTemplateInterface`
            class. If a dictionary, it is assumed to be a custom prompt template mapping role to the
            prepend/append tags.

    Returns:
        Llama3Tokenizer: Instantiation of the Llama3 tokenizer
    """
    # Parse special tokens only when a file was given; None means the tokenizer
    # uses the canonical Llama3 special tokens.
    special_tokens = None
    if special_tokens_path is not None:
        special_tokens = parse_hf_tokenizer_json(special_tokens_path)

    # Resolve the user-supplied template spec (dotpath or dict) into a template
    # object; leave it as None when no template was requested.
    template = None
    if prompt_template is not None:
        template = _get_prompt_template(prompt_template)

    return Llama3Tokenizer(
        path=path,
        special_tokens=special_tokens,
        max_seq_len=max_seq_len,
        prompt_template=template,
    )
def lora_llama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
    use_dora: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Llama3 8B model with LoRA enabled.

    The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_8b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        quantize_base (bool): Whether to quantize base model weights. Default: False
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
            Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3 8B model with LoRA applied
    """
    return lora_llama3(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        # Architecture hyperparameters: same values as ``llama3_8b``.
        vocab_size=128_256,
        num_layers=32,
        num_heads=32,
        num_kv_heads=8,
        embed_dim=4096,
        max_seq_len=8192,
        intermediate_dim=14336,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        # LoRA-specific settings forwarded from the caller.
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
        use_dora=use_dora,
    )
def lora_llama3_70b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
    use_dora: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Llama3 70B model with LoRA enabled.

    The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_70b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation. Default: 8
        lora_alpha (float): scaling factor for the low-rank approximation. Default: 16
        lora_dropout (float): dropout probability for the low-rank approximation. Default: 0.0
        quantize_base (bool): Whether to quantize base model weights. Default: False
        use_dora (bool): Decompose the LoRA weight into magnitude and direction, as
            introduced in "DoRA: Weight-Decomposed Low-Rank Adaptation" (https://arxiv.org/abs/2402.09353).
            Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama3 70B model with LoRA applied
    """
    return lora_llama3(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        # Architecture hyperparameters: same values as ``llama3_70b``.
        vocab_size=128_256,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        max_seq_len=8192,
        intermediate_dim=28672,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500_000,
        # LoRA-specific settings forwarded from the caller.
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
        use_dora=use_dora,
    )
# QLoRA variant of ``lora_llama3_8b``: the same builder with ``quantize_base``
# pre-bound to True via ``functools.partial``; all other arguments pass through.
qlora_llama3_8b = partial(lora_llama3_8b, quantize_base=True)
# ``partial`` objects do not automatically take on the wrapped function's
# docstring, so one is attached explicitly here.
qlora_llama3_8b.__doc__ = """
Builder for creating a Llama3 8B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama3_8b` for full API arguments.
"""
# QLoRA variant of ``lora_llama3_70b``: the same builder with ``quantize_base``
# pre-bound to True via ``functools.partial``; all other arguments pass through.
qlora_llama3_70b = partial(lora_llama3_70b, quantize_base=True)
# ``partial`` objects do not automatically take on the wrapped function's
# docstring, so one is attached explicitly here.
qlora_llama3_70b.__doc__ = """
Builder for creating a Llama3 70B model with QLoRA enabled. Base model weights in linear layers
that LoRA is applied to are quantized per the QLoRA paper: https://arxiv.org/abs/2305.14314.
Please see `lora_llama3_70b` for full API arguments.
"""