diff --git a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
index d172199b65..99aedae576 100644
--- a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
+++ b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
@@ -10,7 +10,8 @@
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
     # force fp16 precision calculation
-    _dtype = query.dtype
+    origin_dtype = query.dtype
+    dtype = origin_dtype if dtype is None else dtype
 
     if dtype is not None:
         query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)
@@ -19,14 +20,14 @@
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
             axis=-1,
-        ).astype(_dtype)
+        ).astype(dtype)
     else:
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
-        ).astype(_dtype)
+        ).astype(dtype)
     out = ops.matmul(attn_weight, value)
-    out = out.astype(_dtype)
+    out = out.astype(origin_dtype)
     return out
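
Usage sketch (not part of the patch): a minimal example of how the patched function behaves, assuming the module is importable under the path shown below; the import path and tensor shapes are illustrative assumptions only. After the change, `dtype` controls the compute precision and falls back to the input dtype when not given, while `origin_dtype` ensures the output is cast back to the caller's dtype.

# Minimal usage sketch for the patched scaled_dot_product_attention.
# Assumptions: the import path and the (batch, heads, seq, head_dim)
# shapes below are illustrative, not taken from the patch itself.
import numpy as np
import mindspore as ms
from mindspore import Tensor

from gm.modules.transformers.transformers import scaled_dot_product_attention

query = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))
key = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))
value = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))

# Compute attention in fp16; the result is cast back to the input dtype
# (fp32 here) because the final astype now uses origin_dtype.
out = scaled_dot_product_attention(query, key, value, dtype=ms.float16)
assert out.dtype == ms.float32

# With dtype=None, dtype falls back to origin_dtype, so the intermediate
# casts are no-ops and the output dtype still matches the input.
out_default = scaled_dot_product_attention(query, key, value)
assert out_default.dtype == ms.float32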