From 1df43ede08977fd227e72db703216e1338ae411a Mon Sep 17 00:00:00 2001
From: Fzilan
Date: Mon, 1 Jul 2024 12:10:48 +0800
Subject: [PATCH 1/2] sdxl decoder attn type bug-fix

---
 .../gm/modules/transformers/transformers.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
index d172199b65..acf09eb723 100644
--- a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
+++ b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
@@ -10,7 +10,8 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
     # force fp16 precision calculation
-    _dtype = query.dtype
+    origin_dtype = query.dtype
+    dtype = dtype or origin_dtype
     if dtype is not None:
         query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)

@@ -19,14 +20,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
             axis=-1,
-        ).astype(_dtype)
+        ).astype(dtype)
     else:
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
-        ).astype(_dtype)
+        ).astype(dtype)
     out = ops.matmul(attn_weight, value)
-    out = out.astype(_dtype)
+    out = out.astype(origin_dtype)
     return out

From 1cd4b13be3eaca058858e8e1dd2c83f6ae2a6857 Mon Sep 17 00:00:00 2001
From: Fzilan
Date: Mon, 22 Jul 2024 10:33:35 +0800
Subject: [PATCH 2/2] fix

---
 .../stable_diffusion_xl/gm/modules/transformers/transformers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
index acf09eb723..99aedae576 100644
--- a/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
+++ b/examples/stable_diffusion_xl/gm/modules/transformers/transformers.py
@@ -11,7 +11,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
     # force fp16 precision calculation
     origin_dtype = query.dtype
-    dtype = dtype or origin_dtype
+    dtype = origin_dtype if dtype is None else dtype
     if dtype is not None:
         query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)
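For reference, below is a minimal sketch of how `scaled_dot_product_attention` reads once both commits are applied. The `mindspore` imports and the `if attn_mask is not None:` branch header are assumptions taken from the surrounding module (they fall outside the hunks shown above); every other line comes from the diff context and the `+` lines.

```python
# Sketch of scaled_dot_product_attention after PATCH 1/2 and PATCH 2/2.
# Imports and the `if attn_mask is not None:` line are assumed from the
# surrounding file; the rest follows the diff above.
import mindspore as ms
from mindspore import ops


def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
    # force fp16 precision calculation
    origin_dtype = query.dtype  # remember the caller's dtype
    dtype = origin_dtype if dtype is None else dtype  # PATCH 2/2: explicit None check
    if dtype is not None:
        query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)

    if attn_mask is not None:  # assumed branch header, implied by the `else:` below
        attn_weight = ops.softmax(
            ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
            axis=-1,
        ).astype(dtype)  # PATCH 1/2: keep the weights in the compute dtype, not the input dtype
    else:
        attn_weight = ops.softmax(
            ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
        ).astype(dtype)
    out = ops.matmul(attn_weight, value)
    out = out.astype(origin_dtype)  # PATCH 1/2: cast the result back to the caller's dtype
    return out
```

On the second commit: `dtype = dtype or origin_dtype` would also fall back to `origin_dtype` for any falsy `dtype` value, whereas the explicit `None` check only falls back when no compute dtype is passed, stating the intent (compute in the requested dtype, restore the original dtype on the output) unambiguously.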