Skip to content

Commit

Permalink
Fix broken tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Aug 24, 2022
1 parent f00e291 commit f4e589d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/triton_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

@triton.jit
def add_kernel(x_ptr, y_ptr, output_ptr,
block_size: tl.constexpr, n_elements: tl.constexpr):
block_size: tl.constexpr, n_elements: tl.constexpr):
pid = tl.program_id(axis=0) # we use a 1d launch grid so axis is 0
block_start = pid * block_size
offsets = block_start + tl.arange(0, block_size)
Expand Down Expand Up @@ -98,15 +98,15 @@ class TritonKernelCallTest(parameterized.TestCase):
])
def test_add_vectors(self, size, dtype, block_size):

grid = lambda meta: (size // meta["BLOCK_SIZE"] + 1,)
grid = lambda meta: (size // meta["block_size"] + 1,)
k1, k2 = random.split(random.PRNGKey(0), 2)
if dtype in {"float32", "float16", "float64"}:
x, y = random.normal(k1, [size], dtype=dtype), random.normal(k2, [size], dtype=dtype)
elif dtype in {"int32", "int64"}:
x, y = random.randint(k1, [size], -100, 100, dtype=dtype), random.randint(k2, [size], -100, 100, dtype=dtype)

out = triton_call(x, y, kernel=add_kernel, out_shape=x,
grid=grid, BLOCK_SIZE=block_size, n_elements=size)
grid=grid, block_size=block_size, n_elements=size)
expected = x + y
np.testing.assert_allclose(out, expected)

Expand Down

0 comments on commit f4e589d

Please sign in to comment.