
Racecheck Bug when tl.min used with tl.sum #4736

Open
thumbe3 opened this issue Sep 17, 2024 · 5 comments

Comments

@thumbe3

thumbe3 commented Sep 17, 2024

import numpy as np
import torch
import triton
import triton.language as tl


@triton.jit
def compute_min_distance_coord(input_ptr: tl.tensor,
                               coord_ptr: tl.tensor,
                               min_cord_idx_ptr: tl.tensor,
                               BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offs_input_row = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offs_coord_row = tl.arange(0, 32)
    offs_coord_idxs = tl.arange(0, 8)

    offs_input = offs_input_row[:, None] * 8 + offs_coord_idxs[None, :]
    offs_coord = offs_coord_row[:, None] * 8 + offs_coord_idxs[None, :]
    input = tl.load(input_ptr + offs_input)
    coord = tl.load(coord_ptr + offs_coord)

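    # Broadcast to shape [BLOCK_SIZE, 32, 8]: every input row against every coordinate.
    diff = input[:, None, :] - coord[None, :, :]
    dist_sq = diff * diff
    # Sum over the 8 components, then take the index of the minimum over the 32 coordinates.
    _, min_coord_idxs = tl.min(tl.sum(dist_sq, axis=-1), axis=-1, return_indices=True)
    tl.store(min_cord_idx_ptr + offs_input_row, min_coord_idxs.to(tl.int32))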



input = torch.rand(1 << 20, 8, dtype=torch.float32).cuda(0)
coordinates = torch.rand(32, 8, dtype=torch.float32).cuda(0)
out_min_idxs = torch.zeros([1 << 20], dtype=torch.int32).cuda(0)

grid = lambda meta: (triton.cdiv(1<<20, meta['BLOCK_SIZE']),)
compute_min_distance_coord[grid](input, coordinates, out_min_idxs, BLOCK_SIZE=512)
torch.cuda.synchronize()

# Equivalent NumPy code to check correctness
input_np = input.cpu().numpy()
coord_np = coordinates.cpu().numpy()
diff_sq = np.square(input_np[:, None, :] - coord_np[None, :, :])
out_min_idxs_np = np.argmin(np.sum(diff_sq, axis=-1), axis=-1)
print(np.allclose(out_min_idxs_np, out_min_idxs.cpu().numpy()))

In the above code, I compute the squared distance between each input row and each of 32 coordinates, and return the index of the coordinate closest to each input row (this may be easier to understand from the NumPy check at the end of the script). When I run this script under the racecheck tool of compute-sanitizer (compute-sanitizer --tool=racecheck python script.py), the following output is shown:

========= Error: Race reported between Write access at compute_min_distance_coord+0x5ad20 in /usr/local/lib/python3.10/dist-packages/triton/language/standard.py:237
========= and Write access at compute_min_distance_coord+0x5ad20 in /usr/local/lib/python3.10/dist-packages/triton/language/standard.py:237 [6136 hazards]

The error seems to stem from standard.py, at a line that appears to be inside the min function.

I am not facing a correctness issue with this code at the moment, but I have faced correctness issues with other kernels using a similar combination of tl.sum with tl.min.
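When comparing argmin indices from two implementations, an exact index comparison can flag ties or near-ties where both answers are equally valid. The sketch below is a pure-Python illustration of a more tolerant check: it accepts a candidate index if its squared distance is within a relative tolerance of the true minimum. The helper names `nearest_index` and `index_is_acceptable` and the `rel_tol` parameter are hypothetical, introduced only for this example, and it makes no claim about how tl.min breaks ties.

```python
def nearest_index(point, coords):
    """Return (index, squared distance) of the coordinate closest to point."""
    best_idx, best_dist = -1, float("inf")
    for idx, coord in enumerate(coords):
        dist = sum((p - c) ** 2 for p, c in zip(point, coord))
        if dist < best_dist:  # strict '<' keeps the first minimum
            best_idx, best_dist = idx, dist
    return best_idx, best_dist


def index_is_acceptable(point, coords, candidate_idx, rel_tol=1e-6):
    """Accept candidate_idx if its distance is within rel_tol of the minimum."""
    _, best_dist = nearest_index(point, coords)
    cand_dist = sum((p - c) ** 2 for p, c in zip(point, coords[candidate_idx]))
    return cand_dist <= best_dist * (1.0 + rel_tol)


# Toy data (hypothetical): the point is equidistant from coords[0] and coords[1].
coords = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
point = [0.5, 0.0]
ref_idx, ref_dist = nearest_index(point, coords)
```

With this kind of check, a kernel that returns index 1 for the toy point above would still pass, while index 2 would be rejected as a real mismatch.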

lijinpei added a commit to lijinpei/triton that referenced this issue Sep 22, 2024
@lijinpei
Contributor

I have created a WIP patch, lijinpei@3fe20ba, which resolves the race reported for the provided script.py and fails no cases in python/test/unit/language/test_core.py (except 'python/test/unit/language/test_core.py::test_dot[1-128-128-64-4-False-False-chain-dot-ieee-float8e5-float32-1]', which already fails on the main branch on my machine).
Can you try the patch on the 'correctness issues with other kernels using similar combination tl.sum with tl.min', or help provide one of those kernels as a unit test? I think the gatekeepers won't accept the patch without a unit test.

@Jokeren
Contributor

Jokeren commented Sep 22, 2024

We likely won't accept your solution even with a unit test. I don't see correctness issues.

@Jokeren
Contributor

Jokeren commented Sep 22, 2024

> But I have faced correctness issues with other kernels using similar combination tl.sum with tl.min

Since, IIUC, the data races in this specific case don't cause correctness problems for you, it would be better to provide the code that has real issues.

Data races can be reported when multiple threads write the same value to the same location, which is fine in Triton.
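To illustrate why a write-write hazard with identical values is harmless, here is a small pure-Python model (not Triton or racecheck internals; the `final_values` helper is hypothetical). It enumerates every ordering of several writes to one location and collects the possible final values. When all writers store the same value, only one outcome is possible regardless of ordering.

```python
from itertools import permutations


def final_values(writes):
    """Return the set of final values over every ordering of the given writes.

    Each write is (thread_id, value); all writes target the same location.
    """
    outcomes = set()
    for order in permutations(writes):
        slot = None
        for _thread_id, value in order:
            slot = value  # the last write in this ordering wins
        outcomes.add(slot)
    return outcomes


same_value = [(0, 7), (1, 7), (2, 7)]
different_values = [(0, 7), (1, 8), (2, 9)]

benign = final_values(same_value)          # a single possible outcome
harmful = final_values(different_values)   # outcome depends on the ordering
```

A race detector that looks only at conflicting accesses would flag both cases, but only the second one can change the result.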

@peterbell10
Contributor

Out of curiosity, I profiled the repro before and after the change, and I do see a small (~1%) speedup that reproduces consistently.

@Jokeren
Contributor

Jokeren commented Sep 24, 2024

> Out of curiosity I profiled the repro before and after the change I do see a small (~1%) speedup that reproduces consistently.

I think we need to run internal regression benchmarks instead of external ones.
