
add stream-k v0.2 #652

Merged
merged 7 commits into main_perf from streamkv0.2 on Oct 31, 2024

Conversation

xiaohuguo2023 (Member):

streamk v0.2:

  • new streamk tuning script to reduce compiling and profiling time

  • use load/store cache modifiers to reimplement the spin lock (store-side excerpt and a consumer-side sketch below)

  • add CI test for the streamk kernel

  • able to use streampipelineV2

rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N)
P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
tl.store(P_, acc, cache_modifier=".wt")
tl.store(locks + pid, 1, cache_modifier=".wt")
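The excerpt above is the producer side of the spin lock mentioned in the description: the partial tile is written through the cache with ".wt", then the per-program lock flag is raised. As a minimal, hypothetical sketch of the hand-off pattern (not the PR's kernel; handoff_kernel and all buffer names here are invented for illustration), the pairing of the ".wt" stores with ".cv" spin loads could look like the following. It assumes both programs are resident at the same time, which stream-k arranges by launching at most NUM_SMS programs, and per the review below it would not work on gfx90a, where these cache modifiers do not take effect.

import torch
import triton
import triton.language as tl

@triton.jit
def handoff_kernel(data_ptr, locks_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    if pid == 0:
        # producer: publish the data, then raise the flag; ".wt" writes through
        # the cache so the consumer's ".cv" load can observe it
        tl.store(data_ptr + offs, tl.full((BLOCK,), 42.0, tl.float32), cache_modifier=".wt")
        tl.store(locks_ptr, 1, cache_modifier=".wt")
    else:
        # consumer: spin until the flag is set, then read the published data
        while tl.load(locks_ptr, cache_modifier=".cv", volatile=True) != 1:
            pass
        vals = tl.load(data_ptr + offs, cache_modifier=".cv")
        tl.store(out_ptr + offs, vals)

BLOCK = 128
data = torch.zeros(BLOCK, device="cuda", dtype=torch.float32)
locks = torch.zeros(1, device="cuda", dtype=torch.int32)
out = torch.empty(BLOCK, device="cuda", dtype=torch.float32)
handoff_kernel[(2,)](data, locks, out, BLOCK=BLOCK)
assert torch.all(out == 42.0)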
Member:

Copying some of my notes on gfx90a again: on gfx90a, loads/stores with cache_modifiers do not work. Documented here: https://github.com/ROCm/triton-internal/issues/311

Member Author:

Not sure how we are going to address this?

Member:

I think there might be an arch == gfx90a check we can use for this and for the pid renaming; I'll check.
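One way such a gate could look at the host/tuning-script level, purely as an assumption (the exact check is not part of this PR, and gcnArchName is only exposed by ROCm builds of PyTorch):

import torch

# Hypothetical arch gate: fall back to the atomic_cas spin loop on gfx90a,
# where loads/stores with cache modifiers reportedly do not work.
arch = torch.cuda.get_device_properties(0).gcnArchName  # e.g. "gfx90a:sramecc+:xnack-"
use_cache_modifier_lock = not arch.startswith("gfx90a")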

# todo: try use tl.load once cache modifier landed upstream
while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
while (end < tile_iter_end and next_pid < NUM_SMS):
while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
Member:

This also does not work on gfx90a: https://github.com/ROCm/triton-internal/issues/311

Member Author:

I will find an MI250 to test it.

Member Author:

Added NUM_XCDS so we can switch it on/off.

Member Author:

Sorry, NUM_XCDS can't help with the cache_modifier issue.

Member:

Yeah, we need something else. I'll investigate.

python/perf-kernels/streamk/streamk_kernel.py
matmul_call_str = f"""
if '{configStr}' not in failed_configs:
rotating_num = tensors['rotating_num']
locks = torch.zeros(({runs}, {num_sms}), device = "cuda", dtype = torch.int32)
Member:

locks can be smaller than int32; we only need 1 byte, so uint8 should work. Tensile uses uint8.
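For illustration only (hypothetical; this PR keeps int32, and the kernel-side flag loads/stores would have to agree on the element type), the suggestion amounts to:

import torch

# Hypothetical 1-byte lock flags; runs and num_sms are example values that the
# tuning script would normally substitute in.
runs, num_sms = 4, 304
locks = torch.zeros((runs, num_sms), device="cuda", dtype=torch.uint8)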

Member Author:

Can we leave this for the next release, as it needs thorough testing? Thanks!

@xiaohuguo2023 xiaohuguo2023 merged commit 1d60b05 into main_perf Oct 31, 2024
4 checks passed
@xiaohuguo2023 xiaohuguo2023 deleted the streamkv0.2 branch October 31, 2024 19:45