add stream-k v0.2 #652
Conversation
Compare: 468765b to 650261b
rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N)
P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
tl.store(P_, acc, cache_modifier=".wt")
tl.store(locks + pid, 1, cache_modifier=".wt")
Copying some of my notes on gfx90a again: on gfx90a, loads/stores with cache_modifiers do not work. Documented here: https://github.com/ROCm/triton-internal/issues/311
Not sure how we are going to address this?
I think there might be an "if arch == gfx90a" check we can use for this and for the pid renaming; I'll check.
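For illustration, a minimal sketch of what such an arch gate could look like. The helper supports_cache_modifiers() and the constexpr flag USE_CACHE_MODIFIER are assumptions for this sketch, not code from this PR:

import torch
import triton
import triton.language as tl

def supports_cache_modifiers():
    # Assumption: ROCm builds of torch expose the GCN arch name; gfx90a is the
    # arch where the ".wt"/".cv" cache modifiers are reported broken.
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return "gfx90a" not in arch

@triton.jit
def write_partials(P, locks, acc, pid,
                   BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                   USE_CACHE_MODIFIER: tl.constexpr):
    # Device-side helper meant to be called from inside the stream-k kernel;
    # acc is the partial accumulator tile produced by this workgroup.
    rm1 = tl.arange(0, BLOCK_SIZE_M)
    rn1 = tl.arange(0, BLOCK_SIZE_N)
    P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
    if USE_CACHE_MODIFIER:
        # Write-through stores so other workgroups observe the partials and the flag.
        tl.store(P_, acc, cache_modifier=".wt")
        tl.store(locks + pid, 1, cache_modifier=".wt")
    else:
        # gfx90a fallback: plain store plus an atomic exchange to publish the lock.
        tl.store(P_, acc)
        tl.atomic_xchg(locks + pid, 1)

The host side would pass USE_CACHE_MODIFIER=supports_cache_modifiers() at launch, so the problematic path is never compiled on gfx90a.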
# todo: try use tl.load once cache modifier landed upstream
while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
while (end < tile_iter_end and next_pid < NUM_SMS):
while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
This also does not work on gfx90a: https://github.com/ROCm/triton-internal/issues/311
I will find an MI250 to test it.
Add NUM_XCDS so we can switch it on/off.
Sorry, NUM_XCDS can't help with the cache_modifier issue.
Yeah, we need something else. I'll investigate.
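For reference, a minimal sketch contrasting the two spin-wait variants discussed above. The constexpr flag SPIN_WITH_CACHE_MODIFIER and the helper name are assumptions for this sketch, not this PR's code:

import triton
import triton.language as tl

@triton.jit
def wait_for_partials(locks, next_pid, SPIN_WITH_CACHE_MODIFIER: tl.constexpr):
    if SPIN_WITH_CACHE_MODIFIER:
        # Cache-bypassing load (".cv", volatile) so the loop sees the producer's
        # store; reported not to work on gfx90a (ROCm/triton-internal#311).
        while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
            pass
    else:
        # Fallback: an atomic compare-and-swap always goes through the memory
        # system, so it also works on gfx90a, at some extra cost.
        while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
            pass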
matmul_call_str = f""" | ||
if '{configStr}' not in failed_configs: | ||
rotating_num = tensors['rotating_num'] | ||
locks = torch.zeros(({runs}, {num_sms}), device = "cuda", dtype = torch.int32) |
locks can be less than int32 type; we only need 1 byte, so uint8 should work. Tensile uses uint8.
Can we leave this for the next release, as we need a thorough test for it? Thanks!
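For illustration, the suggested smaller lock type might look like this (a sketch with made-up sizes, not code from this PR):

import torch

runs, num_sms = 3, 304  # hypothetical values for illustration
# One byte per lock is enough for a 0/1 flag; Tensile uses uint8 as well.
locks = torch.zeros((runs, num_sms), device="cuda", dtype=torch.uint8)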
Two outdated review comments on python/perf-kernels/streamk/03-matrix-multiplication-stream-k.py were resolved.
streamk v0.2:
- new stream-k tuning script to reduce compiling and profiling time
- use load/store cache modifiers to reimplement the spinning lock
- add a CI test for the stream-k kernel
- able to use streampipelineV2