From 0103b995e085209501820421d95c44cd2dfa18a3 Mon Sep 17 00:00:00 2001
From: klhhhhh <1412841649@qq.com>
Date: Tue, 4 Jul 2023 14:35:55 +0800
Subject: [PATCH 1/3] [shardformer] added tests

---
 tests/kit/model_zoo/transformers/__init__.py |  1 +
 tests/kit/model_zoo/transformers/vit.py      | 54 +++++++++++++++++++
 .../test_model/test_shard_vit.py             |  1 +
 3 files changed, 56 insertions(+)
 create mode 100644 tests/kit/model_zoo/transformers/vit.py

diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 4aa01abe13ee..a298767d12e7 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -5,3 +5,4 @@
 from .llama import *
 from .opt import *
 from .t5 import *
+from .vit import *
diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py
new file mode 100644
index 000000000000..6f186c9d55a9
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/vit.py
@@ -0,0 +1,54 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register single-sentence VIT
+# ===============================
+
+
+# define data gen function
+def data_gen():
+
+    pixel_values = torch.randn(1, 3, 224, 224)
+    return dict(pixel_values = pixel_values)
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# function to get the loss
+loss_fn_for_vit_model = lambda x : x.pooler_output.mean()
+loss_fn = lambda x : x.loss
+
+config = transformers.ViTConfig(num_hidden_layers=4,
+                                hidden_size=128,
+                                intermediate_size=256,
+                                num_attention_heads=4)
+
+# register the following models
+# transformers.ViTModel,
+# transformers.ViTForMaskedImageModeling,
+# transformers.ViTForImageClassification,
+model_zoo.register(name = 'transformers_vit',
+                   model_fn = lambda : transformers.ViTModel(config),
+                   data_gen_fn = data_gen,
+                   output_transform_fn = output_transform_fn,
+                   loss_fn = loss_fn_for_vit_model,
+                   model_attribute = ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name = 'transformers_vit_for_masked_image_modeling',
+                   model_fn = lambda : transformers.ViTForMaskedImageModeling(config),
+                   data_gen_fn = data_gen,
+                   output_transform_fn = output_transform_fn,
+                   loss_fn = loss_fn,
+                   model_attribute = ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name = 'transformers_vit_for_image_classification',
+                   model_fn = lambda : transfomers.ViTForImageClassification(config),
+                   data_gen_fn = data_gen,
+                   output_transform_fn = output_transform_fn,
+                   loss_fn = loss_fn,
+                   model_attribute = ModelAttribute(has_control_flow=True))
+
+
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index af1605b6b659..58a4e156b4e2 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -37,6 +37,7 @@ def check_vit(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    print(sub_model_zoo)
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
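
Note: the registered ViT entries above can also be smoke-tested without ShardFormer. The
following is a minimal sketch (not part of the patch) that mirrors data_gen and
loss_fn_for_vit_model from the new tests/kit/model_zoo/transformers/vit.py:

    # Single-process sanity check of the registered ViTModel entry (sketch only).
    import torch
    import transformers

    config = transformers.ViTConfig(num_hidden_layers=4,
                                    hidden_size=128,
                                    intermediate_size=256,
                                    num_attention_heads=4)
    model = transformers.ViTModel(config)
    data = dict(pixel_values=torch.randn(1, 3, 224, 224))   # same shape as data_gen()
    output = model(**data)
    loss = output.pooler_output.mean()                       # same as loss_fn_for_vit_model
    loss.backward()
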
From 61730022c7397b82913b78966cce5f4e19f06ec6 Mon Sep 17 00:00:00 2001
From: klhhhhh <1412841649@qq.com>
Date: Thu, 6 Jul 2023 10:59:42 +0800
Subject: [PATCH 2/3] [shardformer] vit test finish and support

---
 colossalai/shardformer/policies/autopolicy.py |   8 +
 colossalai/shardformer/policies/vit.py        | 170 +++++++++---------
 tests/kit/model_zoo/transformers/vit.py       |  28 +--
 .../test_model/test_shard_vit.py              |  63 ++++---
 4 files changed, 157 insertions(+), 112 deletions(-)

diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py
index 8051433e8d71..f49a552c82b3 100644
--- a/colossalai/shardformer/policies/autopolicy.py
+++ b/colossalai/shardformer/policies/autopolicy.py
@@ -71,6 +71,14 @@ class PolicyLocation:
     "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
         PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
 
+    # ViT
+    "transformers.models.vit.modeling_vit.ViTModel":
+        PolicyLocation(file_name="vit", class_name="ViTPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForImageClassification":
+        PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"),
+    "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling":
+        PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"),
+
     # OPT
     "transformers.models.opt.modeling_opt.OPTModel":
         PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
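
Note: the keys added to the table above are fully qualified class names. A minimal sketch
(the actual lookup helper in autopolicy.py is not shown in this diff, so the snippet only
reproduces the key format) of how a ViT instance maps to its new entry:

    # Build the fully qualified class name used as the PolicyLocation key (sketch only).
    import transformers

    config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128,
                                    intermediate_size=256, num_attention_heads=4)
    model = transformers.ViTModel(config)
    key = type(model).__module__ + '.' + type(model).__qualname__
    print(key)    # transformers.models.vit.modeling_vit.ViTModel -> ViTPolicy in file "vit"
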
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index eaebe2eee0ba..f66562dea3d2 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -2,11 +2,11 @@
 
 import torch.nn as nn
 
-from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import DropoutForReplicatedInput, DropoutForParallelInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row
 
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ['ViTPolicy']
+__all__ = ['ViTPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy']
 
 
 class ViTPolicy(Policy):
@@ -15,96 +15,104 @@ def config_sanity_check(self):
         pass
 
     def preprocess(self):
-        # Resize embedding
-        vocab_size = self.model.config.vocab_size
-        world_size = self.shard_config.tensor_parallel_size
-
-        if vocab_size % world_size != 0:
-            new_vocab_size = vocab_size + world_size - vocab_size % world_size
-            self.model.resize_token_embeddings(new_vocab_size)
-
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
         from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
 
-        base_policy = {
-            ViTEmbeddings:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=DropoutForReplicatedInput,
-                    )
-                ]),
-            ViTLayer:
-                ModulePolicyDescription(attribute_replacement={
-                    "attention.attention.num_attention_heads":
-                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
-                    "attention.attention.all_head_size":
-                        self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                },
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
+                                                            param_replacement=[],
                                         sub_module_replacement=[
                                             SubModuleReplacementDescription(
-                                                suffix="attention.attention.query",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="attention.attention.key",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="attention.attention.value",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="attention.attention.dropout",
-                                                target_module=DropoutForParallelInput,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="attention.output.dense",
-                                                target_module=Linear1D_Row,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="attention.output.dropout",
-                                                target_module=DropoutForParallelInput,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="intermediate.dense",
-                                                target_module=Linear1D_Col,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="output.dense",
-                                                target_module=Linear1D_Row,
-                                            ),
-                                            SubModuleReplacementDescription(
-                                                suffix="output.dropout",
-                                                target_module=DropoutForParallelInput,
-                                            ),
-                                        ]),
-        }
-
-        # optimization configuration
-        if self.shard_config.enable_fused_normalization:
-            base_policy[ViTAttention].sub_module_replacement.extend([
-                SubModuleReplacementDescription(
-                    suffix="layernorm_before",
-                    target_module=FusedLayerNorm,
-                ),
-                SubModuleReplacementDescription(
-                    suffix="layernorm_after",
-                    target_module=FusedLayerNorm,
+                                                suffix="dropout",
+                                                target_module=DropoutForReplicatedInput,
+                                            )
+                                        ])
+
+            policy[ViTLayer] = ModulePolicyDescription(
+                attribute_replacement={
+                    "attention.attention.num_attention_heads":
+                        self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
+                    "attention.attention.all_head_size":
+                        self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
+                },
+                param_replacement=[],
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.query",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.key",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.value",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.attention.dropout",
+                        target_module=DropoutForReplicatedInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dense",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="attention.output.dropout",
+                        target_module=DropoutForReplicatedInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="intermediate.dense",
+                        target_module=Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dense",
+                        target_module=Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="output.dropout",
+                        target_module=DropoutForReplicatedInput,
+                    ),
+                ]
             )
-            ])
-            base_policy[ViTModel].sub_module_replacement.append(
-                SubModuleReplacementDescription(
-                    suffix="layernorm",
-                    target_module=FusedLayerNorm,
-                ))
-
-        return base_policy
+
+        return policy
+
+
     def new_model_class(self):
         return None
 
     def postprocess(self):
         return self.model
+
+
+class ViTForImageClassificationPolicy(ViTPolicy):
+
+    def module_policy(self):
+        from transformers.models.vit.modeling_vit import ViTForImageClassification
+
+        policy = super().module_policy()
+        if self.shard_config.enable_tensor_parallelism:
+            new_item = {
+                ViTForImageClassification:
+                    ModulePolicyDescription(sub_module_replacement=[
+                        SubModuleReplacementDescription(suffix="classifier",
+                                                        target_module=Linear1D_Col,
+                                                        kwargs=dict(gather_output=True))
+                    ])
+            }
+            policy.update(new_item)
+        return policy
+
+
+class ViTForMaskedImageModelingPolicy(ViTPolicy):
+
+    def module_policy(self):
+        policy = super().module_policy()
+        return policy
+
+
+
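
Note: with the small test config (hidden_size=128, num_attention_heads=4) and a tensor
parallel size of 2, the ViTLayer attribute_replacement above works out as in the sketch
below; the numbers are illustrative only, and the actual splitting is performed by
Linear1D_Col and Linear1D_Row.

    # Worked example of the per-rank attention shapes implied by attribute_replacement.
    num_attention_heads = 4
    hidden_size = 128
    tensor_parallel_size = 2

    heads_per_rank = num_attention_heads // tensor_parallel_size    # 2
    all_head_size_per_rank = hidden_size // tensor_parallel_size    # 64
    head_dim = hidden_size // num_attention_heads                   # 32, unchanged

    # query/key/value become Linear1D_Col (output dim split across ranks) and
    # attention.output.dense becomes Linear1D_Row (input dim split), so each rank
    # holds a 128 x 64 slice of each projection instead of the full 128 x 128 weight.
    assert heads_per_rank * head_dim == all_head_size_per_rank
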
diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py
index 6f186c9d55a9..1c86c7ebc742 100644
--- a/tests/kit/model_zoo/transformers/vit.py
+++ b/tests/kit/model_zoo/transformers/vit.py
@@ -7,24 +7,30 @@
 # Register single-sentence VIT
 # ===============================
 
+config = transformers.ViTConfig(num_hidden_layers=4,
+                                hidden_size=128,
+                                intermediate_size=256,
+                                num_attention_heads=4)
 
 # define data gen function
 def data_gen():
-
     pixel_values = torch.randn(1, 3, 224, 224)
     return dict(pixel_values = pixel_values)
 
+def data_gen_for_masked_image_modeling():
+    data = data_gen()
+    num_patches = (config.image_size // config.patch_size) ** 2
+    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+    data['bool_masked_pos'] = bool_masked_pos
+    return data
+
 # define output transform function
 output_transform_fn = lambda x: x
 
 # function to get the loss
 loss_fn_for_vit_model = lambda x : x.pooler_output.mean()
-loss_fn = lambda x : x.loss
-
-config = transformers.ViTConfig(num_hidden_layers=4,
-                                hidden_size=128,
-                                intermediate_size=256,
-                                num_attention_heads=4)
+loss_fn_for_image_classification = lambda x : x.logits.mean()
+loss_fn_for_masked_image_modeling = lambda x : x.loss
 
 # register the following models
 # transformers.ViTModel,
@@ -39,16 +45,16 @@ def data_gen():
 
 model_zoo.register(name = 'transformers_vit_for_masked_image_modeling',
                    model_fn = lambda : transformers.ViTForMaskedImageModeling(config),
-                   data_gen_fn = data_gen,
+                   data_gen_fn = data_gen_for_masked_image_modeling,
                    output_transform_fn = output_transform_fn,
-                   loss_fn = loss_fn,
+                   loss_fn = loss_fn_for_masked_image_modeling,
                    model_attribute = ModelAttribute(has_control_flow=True))
 
 model_zoo.register(name = 'transformers_vit_for_image_classification',
-                   model_fn = lambda : transfomers.ViTForImageClassification(config),
+                   model_fn = lambda : transformers.ViTForImageClassification(config),
                    data_gen_fn = data_gen,
                    output_transform_fn = output_transform_fn,
-                   loss_fn = loss_fn,
+                   loss_fn = loss_fn_for_image_classification,
                    model_attribute = ModelAttribute(has_control_flow=True))
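
Note: data_gen_for_masked_image_modeling above relies on the ViTConfig defaults
image_size=224 and patch_size=16 (neither is overridden in the test config). A short
sketch of the resulting mask shape:

    # Shape check for the bool_masked_pos tensor built by data_gen_for_masked_image_modeling.
    import torch
    import transformers

    config = transformers.ViTConfig(num_hidden_layers=4, hidden_size=128,
                                    intermediate_size=256, num_attention_heads=4)
    num_patches = (config.image_size // config.patch_size) ** 2    # (224 // 16) ** 2 = 196
    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
    assert bool_masked_pos.shape == (1, 196)
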
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index 58a4e156b4e2..a96fd02ae746 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -1,9 +1,18 @@
+import os
+
 import pytest
 import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
@@ -12,45 +21,59 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
-
+    assert_hf_output_close(org_output, shard_output, atol=1e-4, rtol=1e-4)
 
     # do backward
     org_loss.backward()
     shard_loss.backward()
 
-    # check grad
-    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
-    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
-
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
-
     assert torch.allclose(org_loss, shard_loss, atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
-    assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
 
+    # unwrap model
+    if org_model.__class__.__name__ == 'ViTModel':
+        vit_model = org_model
+        shard_vit_model = sharded_model
+    else:
+        vit_model = org_model.vit
+        shard_vit_model = sharded_model.vit
 
-def check_vit(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    # check attention grad
+    org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight
 
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
+
+
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-    print(sub_model_zoo)
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
     torch.cuda.empty_cache()
 
 
+def check_vit(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_vit_test()
+
+
 @pytest.mark.dist
 @pytest.mark.skip
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_vit():
-    spawn(check_vit, 4)
+    spawn(check_vit, 2)
 
 
 if __name__ == "__main__":

From de5647cdf7bc31e5b8d82c51c8a6a8d91021d1b8 Mon Sep 17 00:00:00 2001
From: klhhhhh <1412841649@qq.com>
Date: Thu, 6 Jul 2023 12:01:45 +0800
Subject: [PATCH 3/3] fix attention dropout

---
 colossalai/shardformer/policies/vit.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index f66562dea3d2..7b035afae22c 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -55,7 +55,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.attention.dropout",
-                        target_module=DropoutForReplicatedInput,
+                        target_module=DropoutForParallelInput,
                     ),
                     SubModuleReplacementDescription(
                         suffix="attention.output.dense",
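
Note: the gradient check added to check_forward_backward in PATCH 2 gathers the
column-sharded query.weight.grad from both ranks before comparing it with the unsharded
model. A standalone sketch of that gather pattern (assuming an already initialized
2-rank process group, as set up by colossalai.launch in the test):

    # All-gather pattern used for the sharded-gradient comparison (sketch only).
    import torch
    import torch.distributed as dist

    def gather_sharded_grad(shard_grad: torch.Tensor, world_size: int = 2) -> torch.Tensor:
        # Each rank contributes its local shard; concatenating along dim 0 rebuilds the
        # full Linear1D_Col weight gradient for comparison with the unsharded model.
        shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
        dist.all_gather(shard_grad_list, shard_grad)
        return torch.cat(shard_grad_list, dim=0)
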