Commit: modify amp
zhanghuiyao committed Nov 15, 2024
1 parent 8048e81 commit 328217e
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions mindone/transformers/mindspore_adapter/amp.py
@@ -1,6 +1,25 @@
 import mindspore as ms
 from mindspore import nn
-from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list
+from mindspore.train.amp import _auto_black_list
 
+HALF_UNFRIENDLY_LAYERS = [
+    nn.BatchNorm1d,
+    nn.BatchNorm2d,
+    nn.BatchNorm3d,
+    nn.LayerNorm,
+    nn.GroupNorm,
+    nn.SiLU,
+    nn.GELU,
+    nn.Softmax,
+    nn.Sigmoid,
+    nn.MaxPool1d,
+    nn.MaxPool2d,
+    nn.MaxPool3d,
+    nn.AvgPool1d,
+    nn.AvgPool2d,
+    nn.AvgPool3d,
+    nn.CrossEntropyLoss,
+]
+
 
 def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16):
@@ -33,25 +52,7 @@ def auto_mixed_precision(network, amp_level="O0", dtype=ms.float16):
elif amp_level == "O1":
raise NotImplementedError
elif amp_level == "O2":
_auto_black_list(
network,
AMP_BLACK_LIST
+ [
nn.GroupNorm,
nn.SiLU,
nn.GELU,
nn.Softmax,
nn.Sigmoid,
nn.MaxPool1d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.AvgPool3d,
nn.CrossEntropyLoss,
],
dtype,
)
_auto_black_list(network, HALF_UNFRIENDLY_LAYERS, dtype)
elif amp_level == "O3":
network.to_float(dtype)
else:
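Note: a minimal usage sketch of the adapted helper after this change (not part of the commit). The toy network, shapes, and input below are hypothetical; only the module path and the auto_mixed_precision signature come from the diff above. Under amp_level "O2", cells listed in HALF_UNFRIENDLY_LAYERS (here BatchNorm1d and GELU) are expected to keep computing in float32 while the remaining cells are cast to float16.

import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

from mindone.transformers.mindspore_adapter.amp import auto_mixed_precision

# Hypothetical toy network, for illustration only.
net = nn.SequentialCell([
    nn.Dense(16, 32),
    nn.BatchNorm1d(32),  # in HALF_UNFRIENDLY_LAYERS -> kept in float32 under O2
    nn.GELU(),           # in HALF_UNFRIENDLY_LAYERS -> kept in float32 under O2
    nn.Dense(32, 8),
])

# _auto_black_list rewrites the sub-cells in place, so reassignment should not be
# needed here; whether the helper also returns the network is not shown in this hunk.
auto_mixed_precision(net, amp_level="O2", dtype=ms.float16)

x = Tensor(np.random.randn(4, 16), ms.float32)
out = net(x)
print(out.shape, out.dtype)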
