Commit
sdxl bug-fix: scaled_dot_product_attention api dtype (#573)
* sdxl decoder attn type bug-fix

* fix

---------

Co-authored-by: Fzilan <[email protected]>
Fzilan and Fzilan authored Jul 27, 2024
1 parent c3d1b70 commit 7e4353a
Showing 1 changed file with 5 additions and 4 deletions.
@@ -10,7 +10,8 @@

 def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
     # force fp16 precision calculation
-    _dtype = query.dtype
+    origin_dtype = query.dtype
+    dtype = origin_dtype if dtype is None else dtype
     if dtype is not None:
         query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)

@@ -19,14 +20,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
             axis=-1,
-        ).astype(_dtype)
+        ).astype(dtype)
     else:
         attn_weight = ops.softmax(
             ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
-        ).astype(_dtype)
+        ).astype(dtype)

     out = ops.matmul(attn_weight, value)
-    out = out.astype(_dtype)
+    out = out.astype(origin_dtype)

     return out
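For context, here is a minimal sketch of the function after this fix, reassembled from the two hunks above. The `attn_mask is not None` condition is inferred from the `else` branch elided by the hunk context, and the `mindspore` imports are assumed from the `ms.`/`ops.` usage in the diff:

    import mindspore as ms
    from mindspore import ops


    def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
        # force fp16 precision calculation
        origin_dtype = query.dtype
        # default the compute dtype to the input dtype instead of leaving it None
        dtype = origin_dtype if dtype is None else dtype
        if dtype is not None:
            query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)

        if attn_mask is not None:  # assumed: branch condition elided from the diff
            # softmax is computed in fp32 for numerical stability,
            # then cast back down to the compute dtype
            attn_weight = ops.softmax(
                ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
                axis=-1,
            ).astype(dtype)
        else:
            attn_weight = ops.softmax(
                ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
            ).astype(dtype)

        out = ops.matmul(attn_weight, value)
        # cast the final output back to the caller's original dtype,
        # not the (possibly overridden) compute dtype
        out = out.astype(origin_dtype)

        return out

As the diff reads, the bug was that `.astype(_dtype)` cast the attention weights and output back to the input dtype even when a different compute `dtype` was requested, so intermediates could silently run in the wrong precision. Tracking `origin_dtype` separately keeps the compute path in `dtype` and casts only the final output back to the caller's dtype.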

