From 7e4353a6e160239987db3ce51d93af4d24927aa9 Mon Sep 17 00:00:00 2001
From: fzilan <33061146+Fzilan@users.noreply.github.com>
Date: Sat, 27 Jul 2024 11:23:52 +0800
Subject: [PATCH] sdxl bug-fix: scaled_dot_product_attention api dtype (#573)

* sdxl decoder attn type bug-fix

* fix

---------

Co-authored-by: Fzilan
---
 .../gm/modules/transformers/transformers.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
index d172199b65..99aedae576 100644
--- a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
+++ b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
@@ -10,7 +10,8 @@
 
 def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
     # force fp16 precision calculation
-    _dtype = query.dtype
+    origin_dtype = query.dtype
+    dtype = origin_dtype if dtype is None else dtype
 
     if dtype is not None:
         query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)
@@ -19,14 +20,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
             axis=-1,
-        ).astype(_dtype)
+        ).astype(dtype)
     else:
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
-        ).astype(_dtype)
+        ).astype(dtype)
 
     out = ops.matmul(attn_weight, value)
-    out = out.astype(_dtype)
+    out = out.astype(origin_dtype)
     return out
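
Note (not part of the patch): a minimal usage sketch of the patched function, assuming a MindSpore environment and that the module is importable as gm.modules.transformers.transformers from the examples/stable_diffusion_xl directory (the import path is an assumption, not confirmed by the patch). As the change reads, the attention weights now stay in the requested compute dtype so that ops.matmul(attn_weight, value) sees matching operand dtypes, and only the final output is cast back to the caller's original dtype.

# Minimal usage sketch; assumed import path and shapes, illustrative only.
import numpy as np
import mindspore as ms
from mindspore import Tensor

from gm.modules.transformers.transformers import scaled_dot_product_attention

# fp32 activations, attention computed internally in fp16
q = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))
k = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))
v = Tensor(np.random.randn(1, 8, 77, 64).astype(np.float32))

out = scaled_dot_product_attention(q, k, v, dtype=ms.float16)

# With the fix, attn_weight remains fp16 to match the fp16 value tensor in the
# second matmul; the result is cast back to the original fp32 before returning.
assert out.dtype == ms.float32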