npu knn/tnn bugfix
lizekai committed Jun 12, 2024
1 parent a65aa0f commit aff069d
Showing 4 changed files with 92 additions and 4 deletions.
68 changes: 68 additions & 0 deletions mmcv/ops/csrc/common/pytorch_npu_util.hpp
@@ -583,4 +583,72 @@ typedef void (*ReleaseHugeMem)(void *, bool);
} \
} while (false)

+#define EXEC_NPU_CMD_SYNC(aclnn_api, ...) \
+  do { \
+    static const auto getWorkspaceSizeFuncAddr = \
+        GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize"); \
+    static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api); \
+    static const auto initMemAddr = \
+        GetOpApiFuncAddr("InitHugeMemThreadLocal"); \
+    static const auto unInitMemAddr = \
+        GetOpApiFuncAddr("UnInitHugeMemThreadLocal"); \
+    static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem"); \
+    TORCH_CHECK( \
+        getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr, \
+        #aclnn_api, " or ", #aclnn_api "GetWorkspaceSize", " not in ", \
+        GetOpApiLibName(), ", or ", GetOpApiLibName(), "not found."); \
+    auto acl_stream = c10_npu::getCurrentNPUStream().stream(false); \
+    uint64_t workspace_size = 0; \
+    uint64_t *workspace_size_addr = &workspace_size; \
+    aclOpExecutor *executor = nullptr; \
+    aclOpExecutor **executor_addr = &executor; \
+    InitHugeMemThreadLocal initMemFunc = \
+        reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr); \
+    UnInitHugeMemThreadLocal unInitMemFunc = \
+        reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
+    if (initMemFunc) { \
+      initMemFunc(nullptr, false); \
+    } \
+    auto converted_params = \
+        ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr); \
+    static auto getWorkspaceSizeFunc = \
+        ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
+    auto workspace_status = call(getWorkspaceSizeFunc, converted_params); \
+    TORCH_CHECK(workspace_status == 0, \
+                "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \
+    void *workspace_addr = nullptr; \
+    if (workspace_size != 0) { \
+      at::TensorOptions options = \
+          at::TensorOptions(torch_npu::utils::get_npu_device_type()); \
+      auto workspace_tensor = \
+          at::empty({workspace_size}, options.dtype(kByte)); \
+      workspace_addr = const_cast<void *>(workspace_tensor.storage().data()); \
+    } \
+    auto acl_call = [converted_params, workspace_addr, workspace_size, \
+                     acl_stream, executor]() -> int { \
+      typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, \
+                               const aclrtStream); \
+      OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr); \
+      auto api_ret = \
+          opApiFunc(workspace_addr, workspace_size, executor, acl_stream); \
+      TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", \
+                  aclGetRecentErrMsg()); \
+      ReleaseConvertTypes(converted_params); \
+      ReleaseHugeMem releaseMemFunc = \
+          reinterpret_cast<ReleaseHugeMem>(releaseMemAddr); \
+      if (releaseMemFunc) { \
+        releaseMemFunc(nullptr, false); \
+      } \
+      return api_ret; \
+    }; \
+    at_npu::native::OpCommand cmd; \
+    cmd.Name(#aclnn_api); \
+    cmd.SetCustomHandler(acl_call); \
+    cmd.Run(); \
+    cmd.Sync(); \
+    if (unInitMemFunc) { \
+      unInitMemFunc(nullptr, false); \
+    } \
+  } while (false)
+
#endif // MMCV_OPS_CSRC_COMMON_PYTORCH_NPU_UTIL_HPP_
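
Note (not part of the commit): the new EXEC_NPU_CMD_SYNC is invoked the same way as the existing EXEC_NPU_CMD, but it routes the launch through at_npu::native::OpCommand and blocks on cmd.Sync() before returning, so downstream code can treat the output tensor as ready. A minimal call-site sketch, in a translation unit that includes the header above and assuming the Ascend/torch_npu toolchain and the aclnnKnn operator used below (knn_dist_sync_sketch is a hypothetical helper, not an mmcv function):

// Sketch only: synchronously fill dist with the pairwise distance matrix
// between the M query points and the N reference points; because the macro
// syncs, dist can immediately feed a second launch such as aclnnTopk.
void knn_dist_sync_sketch(const at::Tensor &source,  // [B, 3, N], contiguous
                          const at::Tensor &target,  // [B, M, 3], contiguous
                          at::Tensor &dist) {        // [B, M, N], pre-allocated
  bool is_from_knn = true;
  EXEC_NPU_CMD_SYNC(aclnnKnn, source, target, is_from_knn, dist);
}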
12 changes: 10 additions & 2 deletions mmcv/ops/csrc/pytorch/npu/knn_npu.cpp
@@ -8,11 +8,19 @@ using namespace std;
void knn_forward_npu(int b, int n, int m, int nsample, const Tensor xyz,
                     const Tensor new_xyz, Tensor idx, Tensor dist2) {
  // transpose known from [B, N, 3] to [B, 3, N]
-  at::Tensor source = xyz.transpose(1, 2).contiguous();
+  at::Tensor source = xyz.transpose(2, 1).contiguous();
  at::Tensor target = new_xyz.contiguous();

+  at::Tensor dist = at::zeros({target.sizes()[0], target.sizes()[1], source.sizes()[2]}, target.options());
  bool is_from_knn = true;
-  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+  EXEC_NPU_CMD_SYNC(aclnnKnn, source, target, is_from_knn, dist);
+
+  idx.to(at::kLong);
+  int64_t dim = 2;
+  bool largest = false;
+  bool sorted = true;
+  EXEC_NPU_CMD_SYNC(aclnnTopk, dist, nsample, dim, largest, sorted, dist2, idx);
+  idx.to(at::kInt);
}

void knn_forward_impl(int b, int n, int m, int nsample, const Tensor xyz,
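
As a cross-check (not part of the commit), the new two-step pipeline (aclnnKnn filling a full [B, M, N] distance matrix, then aclnnTopk taking the nsample smallest entries along dim 2) corresponds to the following plain ATen computation, up to tie-breaking for equal distances and assuming the kernel writes squared Euclidean distances as the dist2 name suggests. knn_reference is a hypothetical helper, not mmcv API:

#include <ATen/ATen.h>
#include <tuple>

// CPU reference for the dist + topk pipeline above (sketch only).
// xyz: [B, N, 3] reference points, new_xyz: [B, M, 3] query points.
std::tuple<at::Tensor, at::Tensor> knn_reference(const at::Tensor &xyz,
                                                 const at::Tensor &new_xyz,
                                                 int64_t nsample) {
  // dist[b, m, n] = squared distance between new_xyz[b, m] and xyz[b, n]
  at::Tensor dist = at::cdist(new_xyz, xyz).pow(2);  // [B, M, N]
  at::Tensor dist2, idx;
  // nsample smallest distances per query point, in ascending order
  std::tie(dist2, idx) =
      dist.topk(nsample, /*dim=*/2, /*largest=*/false, /*sorted=*/true);
  return std::make_tuple(dist2, idx.to(at::kInt));  // NPU path keeps idx as int32
}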
13 changes: 11 additions & 2 deletions mmcv/ops/csrc/pytorch/npu/three_nn_npu.cpp
@@ -8,17 +8,26 @@ using namespace std;
void three_nn_forward_npu(int b, int n, int m, const Tensor unknown,
                          const Tensor known, Tensor dist2, Tensor idx) {
  // transpose known [B, N, 3] -> [B, 3, N]
-  at::Tensor source = known.transpose(1, 2).contiguous();
+  at::Tensor source = known.transpose(2, 1).contiguous();
  at::Tensor target = unknown.contiguous();
  auto originDtype = source.scalar_type();
  if (originDtype == at::kHalf) {
    source = source.to(at::kFloat);
    target = target.to(at::kFloat);
  }

+  at::Tensor dist = at::zeros({target.sizes()[0], target.sizes()[1], source.sizes()[2]}, target.options());
  bool is_from_knn = false;
+  EXEC_NPU_CMD_SYNC(aclnnKnn, source, target, is_from_knn, dist);
+
+  idx.to(at::kLong);
+  int64_t dim = 2;
+  bool largest = false;
+  bool sorted = true;
  uint32_t nsample = 3;
-  EXEC_NPU_CMD(aclnnKnn, source, target, nsample, is_from_knn, idx, dist2);
+  EXEC_NPU_CMD_SYNC(aclnnTopk, dist, nsample, dim, largest, sorted, dist2, idx);
+  idx.to(at::kInt);
+
  if (originDtype == at::kHalf) {
    dist2 = dist2.to(at::kHalf);
  }
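
The three_nn path is the same pipeline with nsample fixed at 3, plus a round trip through fp32 for half-precision inputs; only the distances need the cast back, since the indices are integer-typed. A sketch in terms of the hypothetical knn_reference helper above (same includes, not part of the commit):

// three_nn reference (sketch only): compute in fp32, take the 3 nearest
// reference points per query, cast the distances back to the input dtype.
std::tuple<at::Tensor, at::Tensor> three_nn_reference(
    const at::Tensor &unknown,  // [B, n, 3] query points
    const at::Tensor &known) {  // [B, m, 3] reference points
  auto origin_dtype = known.scalar_type();
  at::Tensor dist2, idx;
  std::tie(dist2, idx) =
      knn_reference(known.to(at::kFloat), unknown.to(at::kFloat), /*nsample=*/3);
  return std::make_tuple(dist2.to(origin_dtype), idx);
}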
3 changes: 3 additions & 0 deletions mmcv/ops/knn.py
@@ -67,6 +67,9 @@ def forward(ctx,

        ext_module.knn_forward(
            xyz, center_xyz, idx, dist2, b=B, n=N, m=npoint, nsample=k)
+        if xyz.device.type != 'npu':
+            zeros_idx = torch.zeros(B, npoint, k, dtype=torch.int32).npu()
+            idx.where(dist2 >= 1e10, zeros_idx)
        # idx shape to [B, k, npoint]
        idx = idx.transpose(2, 1).contiguous()
        if torch.__version__ != 'parrots':
