Feature/vit support #4182

Merged
Changes from 2 commits
8 changes: 8 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -71,6 +71,14 @@ class PolicyLocation:
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),

# ViT
"transformers.models.vit.modeling_vit.ViTModel":
PolicyLocation(file_name="vit", class_name="ViTPolicy"),
"transformers.models.vit.modeling_vit.ViTForImageClassification":
PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
"transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),

# OPT
"transformers.models.opt.modeling_opt.OPTModel":
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
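
The hunk above registers the three ViT model classes in the policy-location table, so the auto-policy lookup can map a model's fully-qualified class name to the matching policy class in colossalai/shardformer/policies/vit.py. As a rough illustration of how such a table is consumed, here is a hypothetical resolver sketch (the helper name and lookup logic are illustrative, not the actual code in autopolicy.py):

import importlib

# Hypothetical sketch: resolve a policy class from a model instance by its
# fully-qualified class name, mirroring the PolicyLocation entries above.
_POLICY_TABLE = {
    "transformers.models.vit.modeling_vit.ViTModel": ("vit", "ViTPolicy"),
    "transformers.models.vit.modeling_vit.ViTForImageClassification": ("vit", "ViTForImageClassificationPolicy"),
    "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": ("vit", "ViTForMaskedImageModelingPolicy"),
}

def resolve_policy_class(model):
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    file_name, class_name = _POLICY_TABLE[qualname]
    module = importlib.import_module(f"colossalai.shardformer.policies.{file_name}")
    return getattr(module, class_name)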
170 changes: 89 additions & 81 deletions colossalai/shardformer/policies/vit.py
@@ -2,11 +2,11 @@

import torch.nn as nn

from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ViTPolicy']
__all__ = ['ViTPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']


class ViTPolicy(Policy):
@@ -15,96 +15,104 @@ def config_sanity_check(self):
pass

def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size

if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)

return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer

base_policy = {
ViTEmbeddings:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
]),
ViTLayer:
ModulePolicyDescription(attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
},
policy = {}

if self.shard_config.enable_tensor_parallelism:
policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=DropoutForParallelInput,
),
]),
}

# optimization configuration
if self.shard_config.enable_fused_normalization:
base_policy[ViTAttention].sub_module_replacement.extend([
SubModuleReplacementDescription(
suffix="layernorm_before",
target_module=FusedLayerNorm,
),
SubModuleReplacementDescription(
suffix="layernorm_after",
target_module=FusedLayerNorm,
suffix="dropout",
target_module=DropoutForReplicatedInput,
)
])

policy[ViTLayer] = ModulePolicyDescription(
attribute_replacement={
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=DropoutForReplicatedInput,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=DropoutForReplicatedInput,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=DropoutForReplicatedInput,
),
]
)
])
base_policy[ViTModel].sub_module_replacement.append(
SubModuleReplacementDescription(
suffix="layernorm",
target_module=FusedLayerNorm,
))

return base_policy

return policy


def new_model_class(self):
return None

def postprocess(self):
return self.model

class ViTForImageClassificationPolicy(ViTPolicy):

def module_policy(self):
from transformers.models.vit.modeling_vit import ViTForImageClassification

policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
new_item = {
ViTForImageClassification:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(suffix="classifier",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
return policy

class ViTForMaskedImageModelingPolicy(ViTPolicy):

def module_policy(self):
policy = super().module_policy()
return policy
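
In the updated ViTLayer description, query/key/value and intermediate.dense become column-parallel (Linear1D_Col), attention.output.dense and output.dense become row-parallel (Linear1D_Row), the dropouts are replaced with ShardFormer's distributed dropout variants, and the attribute replacement divides num_attention_heads and all_head_size by the tensor-parallel degree. A quick sanity check of that arithmetic with the config used in the tests below (a sketch, not code from this PR):

# Sketch: per-rank sizes under tensor_parallel_size = 2 with the test config
# (hidden_size=128, num_attention_heads=4, intermediate_size=256).
tp_size = 2
hidden_size, num_heads, intermediate_size = 128, 4, 256

heads_per_rank = num_heads // tp_size                 # 2 attention heads per rank
all_head_size_per_rank = hidden_size // tp_size       # 64, matches the attribute replacement
qkv_out_per_rank = hidden_size // tp_size             # Linear1D_Col splits the output dim: 128 -> 64
mlp_hidden_per_rank = intermediate_size // tp_size    # intermediate.dense output per rank: 256 -> 128

# Linear1D_Row keeps the full output dimension but splits the input dimension,
# so attention.output.dense and output.dense reduce partial sums across ranks.
assert heads_per_rank * (hidden_size // num_heads) == all_head_size_per_rank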




1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -5,3 +5,4 @@
from .llama import *
from .opt import *
from .t5 import *
from .vit import *
60 changes: 60 additions & 0 deletions tests/kit/model_zoo/transformers/vit.py
@@ -0,0 +1,60 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

# ===============================
# Register single-sentence VIT
# ===============================

config = transformers.ViTConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4)

# define data gen function
def data_gen():
pixel_values = torch.randn(1, 3, 224, 224)
return dict(pixel_values = pixel_values)

def data_gen_for_masked_image_modeling():
data = data_gen()
num_patches = (config.image_size // config.patch_size) ** 2
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
data['bool_masked_pos'] = bool_masked_pos
return data

# define output transform function
output_transform_fn = lambda x: x

# function to get the loss
loss_fn_for_vit_model = lambda x : x.pooler_output.mean()
loss_fn_for_image_classification = lambda x : x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x : x.loss

# register the following models
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name = 'transformers_vit',
model_fn = lambda : transformers.ViTModel(config),
data_gen_fn = data_gen,
output_transform_fn = output_transform_fn,
loss_fn = loss_fn_for_vit_model,
model_attribute = ModelAttribute(has_control_flow=True))

model_zoo.register(name = 'transformers_vit_for_masked_image_modeling',
model_fn = lambda : transformers.ViTForMaskedImageModeling(config),
data_gen_fn = data_gen_for_masked_image_modeling,
output_transform_fn = output_transform_fn,
loss_fn = loss_fn_for_masked_image_modeling,
model_attribute = ModelAttribute(has_control_flow=True))

model_zoo.register(name = 'transformers_vit_for_image_classification',
model_fn = lambda : transformers.ViTForImageClassification(config),
data_gen_fn = data_gen,
output_transform_fn = output_transform_fn,
loss_fn = loss_fn_for_image_classification,
model_attribute = ModelAttribute(has_control_flow=True))
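
Each registry entry pairs a model constructor with a data generator, an output transform and a loss function, which is exactly what the shardformer test below iterates over. As a rough single-device illustration of how those pieces fit together (a sketch reusing the config above, not part of the PR):

import torch
import transformers

# Sketch: one forward/backward pass with the registered image-classification
# setup, mirroring data_gen and loss_fn_for_image_classification.
config = transformers.ViTConfig(num_hidden_layers=4,
                                hidden_size=128,
                                intermediate_size=256,
                                num_attention_heads=4)
model = transformers.ViTForImageClassification(config)

data = dict(pixel_values=torch.randn(1, 3, 224, 224))
output = model(**data)
loss = output.logits.mean()    # same reduction as loss_fn_for_image_classification
loss.backward()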


62 changes: 43 additions & 19 deletions tests/test_shardformer/test_model/test_shard_vit.py
@@ -1,9 +1,18 @@
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward

@@ -12,44 +12,59 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)

assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4)
# do backward
org_loss.backward()
shard_loss.backward()

# check grad
org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad

shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)

assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"

# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.vit

def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# check attention grad
org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight

if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
else:
all_shard_grad = shard_grad
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

torch.cuda.empty_cache()


def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()


@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
spawn(check_vit, 4)
spawn(check_vit, 2)


if __name__ == "__main__":