Skip to content

Commit

Permalink
Merge pull request #40 from DaGaiBa/rc41.x
Browse files Browse the repository at this point in the history
Bugfix of NPU adapter of nms3d
  • Loading branch information
momo609 authored Jun 22, 2024
2 parents 0356569 + 5a75cd1 commit 3bc1aa1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 15 deletions.
17 changes: 9 additions & 8 deletions mmcv/ops/csrc/pytorch/npu/nms3d_normal_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,24 @@
using namespace NPU_NAME_SPACE;

void iou3d_nms3d_normal_forward_npu(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh) {
Tensor &num_out, float nms_overlap_thresh) {
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
const double iou_threshold = nms_overlap_thresh;
at::Tensor mask =
at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3dNormal, boxes, nms_overlap_thresh, mask);
EXEC_NPU_CMD(aclnnNms3dNormal, boxes, iou_threshold, mask);

keep = at::zeros({box_num}, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
Tensor keep_t = at::zeros({box_num}, mask.options());
Tensor num_out_t = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t);
num_out.fill_(num_out_t.item().toLong());
keep.copy_(keep_t);
}

void iou3d_nms3d_normal_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num,
float nms_overlap_thresh);
Tensor &num_out, float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_normal_forward_impl,
iou3d_nms3d_normal_forward_npu);
18 changes: 11 additions & 7 deletions mmcv/ops/csrc/pytorch/npu/nms3d_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ using namespace std;

constexpr int32_t BOX_DIM = 7;

void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &keep_num,
void iou3d_nms3d_forward_npu(const Tensor boxes, Tensor &keep, Tensor &num_out,
float nms_overlap_thresh) {
TORCH_CHECK((boxes.sizes()[1] == BOX_DIM),
"Input boxes shape should be (N, 7)");
int32_t box_num = boxes.size(0);
int32_t data_align = 16;
int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align;
const double iou_threshold = nms_overlap_thresh;
at::Tensor mask =
at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort));
EXEC_NPU_CMD(aclnnNms3d, boxes, nms_overlap_thresh, mask);
keep = at::zeros({box_num}, mask.options());
keep_num = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, keep_num);
EXEC_NPU_CMD(aclnnNms3d, boxes, iou_threshold, mask);

Tensor keep_t = at::zeros({box_num}, mask.options());
Tensor num_out_t = at::zeros(1, mask.options());
EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep_t, num_out_t);
num_out.fill_(num_out_t.item().toLong());
keep.copy_(keep_t);
}

void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep,
Tensor &keep_num, float nms_overlap_thresh);
void iou3d_nms3d_forward_impl(const Tensor boxes, Tensor &keep, Tensor &num_out,
float nms_overlap_thresh);

REGISTER_NPU_IMPL(iou3d_nms3d_forward_impl, iou3d_nms3d_forward_npu);

0 comments on commit 3bc1aa1

Please sign in to comment.