From 328217e79f8aee71e54384d692da9bebed32d6cd Mon Sep 17 00:00:00 2001 From: zhanghuiyao <1814619459@qq.com> Date: Fri, 15 Nov 2024 15:36:51 +0800 Subject: [PATCH] modify amp --- mindone/transformers/mindspore_adapter/amp.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mindone/transformers/mindspore_adapter/amp.py b/mindone/transformers/mindspore_adapter/amp.py index b950a2332..ea7eb0323 100644 --- a/mindone/transformers/mindspore_adapter/amp.py +++ b/mindone/transformers/mindspore_adapter/amp.py @@ -1,6 +1,25 @@ import mindspore as ms from mindspore import nn -from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list +from mindspore.train.amp import _auto_black_list + +HALF_UNFRIENDLY_LAYERS = [ + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.LayerNorm, + nn.GroupNorm, + nn.SiLU, + nn.GELU, + nn.Softmax, + nn.Sigmoid, + nn.MaxPool1d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.AvgPool3d, + nn.CrossEntropyLoss, +] def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16): @@ -33,25 +52,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16): elif amp_level == "O1": raise NotImplementedError elif amp_level == "O2": - _auto_black_list( - network, - AMP_BLACK_LIST - + [ - nn.GroupNorm, - nn.SiLU, - nn.GELU, - nn.Softmax, - nn.Sigmoid, - nn.MaxPool1d, - nn.MaxPool2d, - nn.MaxPool3d, - nn.AvgPool1d, - nn.AvgPool2d, - nn.AvgPool3d, - nn.CrossEntropyLoss, - ], - dtype, - ) + _auto_black_list(network, HALF_UNFRIENDLY_LAYERS, dtype) elif amp_level == "O3": network.to_float(dtype) else: