-
Notifications
You must be signed in to change notification settings - Fork 0
/
fsdp_utils.py
35 lines (29 loc) · 1.28 KB
/
fsdp_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
def fsdp_auto_wrap_policy(model, transformer_layer_name):
import functools
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
def lambda_policy_fn(module):
if (
len(list(module.named_children())) == 0
and getattr(module, "weight", None) is not None
and module.weight.requires_grad
):
return True
return False
lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
transformer_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=(
PrefixEncoder,
PromptEncoder,
PromptEmbedding,
transformer_layer_name,
# FullyShardedDataParallelPlugin.get_module_class_from_name(
# model, transformer_layer_name
# ),
),
)
auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
return auto_wrap_policy