diff --git a/pypots/nn/modules/reformer/local_attention.py b/pypots/nn/modules/reformer/local_attention.py index a617b9ba..f84e6e56 100644 --- a/pypots/nn/modules/reformer/local_attention.py +++ b/pypots/nn/modules/reformer/local_attention.py @@ -13,7 +13,7 @@ from einops import rearrange from einops import repeat, pack, unpack from torch import nn, einsum -from torch.cuda.amp import autocast +from torch.amp import autocast TOKEN_SELF_ATTN_VALUE = -5e4 @@ -28,7 +28,7 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -@autocast(enabled=False) +@autocast("cuda", enabled=False) def apply_rotary_pos_emb(q, k, freqs, scale=1): q_len = q.shape[-2] q_freqs = freqs[..., -q_len:, :] @@ -95,7 +95,7 @@ def __init__(self, dim, scale_base=None, use_xpos=False): scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) self.register_buffer("scale", scale, persistent=False) - @autocast(enabled=False) + @autocast("cuda", enabled=False) def forward(self, x): seq_len, device = x.shape[-2], x.device