Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sdxl bug-fix: scaled_dot_product_attention api dtype #573

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
# force fp16 precision calculation
_dtype = query.dtype
origin_dtype = query.dtype
dtype = origin_dtype if dtype is None else dtype
if dtype is not None:
query, key, value = query.astype(dtype), key.astype(dtype), value.astype(dtype)

Expand All @@ -19,14 +20,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dtype=None):
attn_weight = ops.softmax(
ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5) + attn_mask, ms.float32),
axis=-1,
).astype(_dtype)
).astype(dtype)
else:
attn_weight = ops.softmax(
ops.cast(ops.matmul(query, key.swapaxes(-2, -1)) / (query.shape[-1] ** 0.5), ms.float32), axis=-1
).astype(_dtype)
).astype(dtype)

out = ops.matmul(attn_weight, value)
out = out.astype(_dtype)
out = out.astype(origin_dtype)

return out

Expand Down