Skip to content

Commit

Permalink
Interfaces change.
Browse files Browse the repository at this point in the history
  • Loading branch information
DaGaiBa committed Jul 19, 2024
1 parent 4482059 commit 02d23c0
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions mmcv/ops/scatter_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,26 @@ def forward(ctx: Any,
reduced from input features that share the same voxel coordinates.
The second is voxel coordinates with shape [M, ndim].
"""
ctx.device = feats.device.type
if ctx.device == 'npu':
import ads_c
voxel_idx = ads_c.point_to_voxel(coors, [], [])
unique_res = ads_c.unique_voxel(voxel_idx)
num_voxels, uniqued_voxel_idx, prefix_sum, \
argsort_coor, _ = unique_res
voxel_coors = ads_c.voxel_to_point(uniqued_voxel_idx, [], [])
voxel_feats, \
compare_mask = ads_c.npu_dynamic_scatter(feats, coors,
prefix_sum,
argsort_coor,
num_voxels,
reduce_type)
ctx.reduce_type = reduce_type
ctx.feats_shape = feats.shape
ctx.save_for_backward(prefix_sum, argsort_coor, compare_mask)
ctx.mark_non_differentiable(voxel_coors)
return voxel_feats, voxel_coors

results = ext_module.dynamic_point_to_voxel_forward(
feats, coors, reduce_type)
(voxel_feats, voxel_coors, point2voxel_map,
Expand All @@ -50,6 +70,19 @@ def forward(ctx: Any,
def backward(ctx: Any,
grad_voxel_feats: torch.Tensor,
grad_voxel_coors: Optional[torch.Tensor] = None) -> tuple:
if ctx.device == 'npu':
import ads_c
prefix_sum, argsort_coor, compare_mask = ctx.saved_tensors
grad_point_feats = torch.zeros(
ctx.feats_shape,
dtype=grad_voxel_feats.dtype,
device=grad_voxel_feats.device)
ads_c.npu_dynamic_scatter_grad(grad_point_feats,
grad_voxel_feats.contiguous(),
prefix_sum, argsort_coor,
compare_mask, ctx.reduce_type)
return grad_point_feats, None, None

(feats, voxel_feats, point2voxel_map,
voxel_points_count) = ctx.saved_tensors
grad_feats = torch.zeros_like(feats)
Expand Down

0 comments on commit 02d23c0

Please sign in to comment.