Skip to content

Commit

Permalink
fix scatter and gather ops, use expand instead of cat to avoid alloc u…
Browse files Browse the repository at this point in the history
…seless memory.
  • Loading branch information
suisiyuan committed Aug 8, 2024
1 parent 2636150 commit 003ac52
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions byte_micro_perf/backends/module_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,11 +550,9 @@ def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device):

index = [i for i in range(batch_size)]
random.shuffle(index)
index_tensor = torch.cat(
[torch.full((1, tensor_len), i, dtype=torch.int64, device=xpu_device) for i in index],
dim=0
)

index_tensor = torch.tensor(index, dtype=torch.int64, device=xpu_device)
index_tensor = index_tensor.reshape(-1, 1).expand(-1, tensor_len)

return [dst_tensor, index_tensor, src_tensor]


Expand Down Expand Up @@ -602,10 +600,8 @@ def custom_create_tensors(self, input_shapes, torch_dtype, xpu_device):

index = [i for i in range(batch_size)]
random.shuffle(index)
index_tensor = torch.cat(
[torch.full((1, tensor_len), i, dtype=torch.int64, device=xpu_device) for i in index],
dim=0
)
index_tensor = torch.tensor(index, dtype=torch.int64, device=xpu_device)
index_tensor = index_tensor.reshape(-1, 1).expand(-1, tensor_len)

return [dst_tensor, index_tensor, src_tensor]

Expand Down

0 comments on commit 003ac52

Please sign in to comment.