diff --git a/mindone/transformers/mindspore_adapter/amp.py b/mindone/transformers/mindspore_adapter/amp.py
index b950a2332..ea7eb0323 100644
--- a/mindone/transformers/mindspore_adapter/amp.py
+++ b/mindone/transformers/mindspore_adapter/amp.py
@@ -1,6 +1,25 @@
 import mindspore as ms
 from mindspore import nn
-from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list
+from mindspore.train.amp import _auto_black_list
+
+HALF_UNFRIENDLY_LAYERS = [
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.LayerNorm,
+    nn.GroupNorm,
+    nn.SiLU,
+    nn.GELU,
+    nn.Softmax,
+    nn.Sigmoid,
+    nn.MaxPool1d,
+    nn.MaxPool2d,
+    nn.MaxPool3d,
+    nn.AvgPool1d,
+    nn.AvgPool2d,
+    nn.AvgPool3d,
+    nn.CrossEntropyLoss,
+]
 
 
 def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16):
@@ -33,25 +52,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16):
     elif amp_level == "O1":
         raise NotImplementedError
     elif amp_level == "O2":
-        _auto_black_list(
-            network,
-            AMP_BLACK_LIST
-            + [
-                nn.GroupNorm,
-                nn.SiLU,
-                nn.GELU,
-                nn.Softmax,
-                nn.Sigmoid,
-                nn.MaxPool1d,
-                nn.MaxPool2d,
-                nn.MaxPool3d,
-                nn.AvgPool1d,
-                nn.AvgPool2d,
-                nn.AvgPool3d,
-                nn.CrossEntropyLoss,
-            ],
-            dtype,
-        )
+        _auto_black_list(network, HALF_UNFRIENDLY_LAYERS, dtype)
 elif amp_level == "O3":
        network.to_float(dtype)
    else:
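
For context, a minimal usage sketch of the refactored helper (the `TinyNet` cell below is hypothetical, not part of this change): under `amp_level="O2"`, `_auto_black_list` keeps cells whose types appear in `HALF_UNFRIENDLY_LAYERS` in float32, while the rest of the network is cast to the requested half-precision dtype.

```python
# Usage sketch; TinyNet is a hypothetical cell used only for illustration.
import mindspore as ms
from mindspore import nn

from mindone.transformers.mindspore_adapter.amp import auto_mixed_precision


class TinyNet(nn.Cell):
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(8, 8)     # cast to float16 under O2
        self.norm = nn.LayerNorm((8,))  # in HALF_UNFRIENDLY_LAYERS, kept in float32

    def construct(self, x):
        return self.norm(self.dense(x))


net = TinyNet()
auto_mixed_precision(net, amp_level="O2", dtype=ms.float16)  # casts cells in place
```

A side benefit of hoisting the list to module level is that callers can import `HALF_UNFRIENDLY_LAYERS` directly instead of re-declaring which layer types should stay in full precision.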