diff --git a/torchbenchmark/util/kernels/triton_fused_attention.py b/torchbenchmark/util/kernels/triton_fused_attention.py
index 872ba8fac..0c84fa25d 100644
--- a/torchbenchmark/util/kernels/triton_fused_attention.py
+++ b/torchbenchmark/util/kernels/triton_fused_attention.py
@@ -62,7 +62,7 @@ def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size):
         else:
             desc_x = self.cuda_descriptors[name]
             buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.numpy())
+            self.fill_1d_tma_descriptor_inner(ptr, dim, block_dim, element_size, buf_x.data_ptr())
             desc_x.copy_(buf_x, non_blocking=True)
 
 
@@ -75,7 +75,7 @@ def fill_2d_tma_descriptor(self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size):
         else:
             desc_x = self.cuda_descriptors[name]
             buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True)
-            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.numpy())
+            self.fill_2d_tma_descriptor_inner(ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr())
             desc_x.copy_(buf_x, non_blocking=True)
 
 