Skip to content

Commit

Permalink
Merge pull request #82 from Bosco-lab/rc4main
Browse files Browse the repository at this point in the history
add pixel_group_npu
  • Loading branch information
hust17yixuan authored Nov 29, 2024
2 parents f27fff7 + cf23718 commit 2ed1460
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 55 deletions.
55 changes: 55 additions & 0 deletions mmcv/ops/csrc/pytorch/npu/pixel_group_npu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "pytorch_npu_helper.hpp"

using namespace NPU_NAME_SPACE;
using namespace std;

/**
 * NPU implementation of pixel grouping.
 *
 * Runs the aclnnPixelGroup kernel and converts its two outputs into the
 * nested-vector result format.
 *
 * @param score              2-D tensor; its options (dtype/device) are used
 *                           for the per-region statistics buffer.
 * @param mask               2-D tensor forwarded to the kernel.
 * @param embedding          3-D tensor forwarded to the kernel.
 * @param kernel_label       2-D tensor; its shape and options define the
 *                           updated-label buffer.
 * @param kernel_contour     2-D tensor forwarded to the kernel.
 * @param kernel_region_num  Number of region slots in the result; must be
 *                           non-negative.
 * @param distance_threshold Forwarded to the kernel unchanged.
 * @return One vector per region index l in [0, kernel_region_num):
 *         element 0 is point_vector[l][0] divided by point_vector[l][1] when
 *         the latter is positive (otherwise left undivided), element 1 is
 *         point_vector[l][1], followed for l > 0 by (column, row) pairs of
 *         every pixel whose updated label equals l, in row-major scan order.
 *         NOTE(review): point_vector columns are presumably the accumulated
 *         score and the pixel count -- confirm against the kernel contract.
 */
vector<vector<float>> pixel_group_npu(Tensor score, Tensor mask, Tensor embedding,
                                      Tensor kernel_label, Tensor kernel_contour,
                                      int kernel_region_num, float distance_threshold) {
  TORCH_CHECK(score.dim() == 2, "score.dim() must be 2, but got: ", score.dim());
  TORCH_CHECK(mask.dim() == 2, "mask.dim() must be 2, but got: ", mask.dim());
  TORCH_CHECK(embedding.dim() == 3, "embedding.dim() must be 3, but got: ", embedding.dim());
  TORCH_CHECK(kernel_label.dim() == 2, "kernel_label.dim() must be 2, but got: ", kernel_label.dim());
  TORCH_CHECK(kernel_contour.dim() == 2, "kernel_contour.dim() must be 2, but got: ", kernel_contour.dim());
  // Fail with a clear message instead of a generic negative-dimension error
  // from the tensor allocation below.
  TORCH_CHECK(kernel_region_num >= 0,
              "kernel_region_num must be non-negative, but got: ", kernel_region_num);

  const int64_t height = kernel_label.size(0);
  const int64_t width = kernel_label.size(1);

  c10::SmallVector<int64_t, 8> point_vector_size = {kernel_region_num, 2};
  c10::SmallVector<int64_t, 8> label_updated_size = {height, width};
  at::Tensor point_vector = at::zeros(point_vector_size, score.options());
  at::Tensor label_updated = at::empty(label_updated_size, kernel_label.options());

  EXEC_NPU_CMD(aclnnPixelGroup, score, mask, embedding, kernel_label, kernel_contour,
               kernel_region_num, distance_threshold, point_vector, label_updated);

  // Bring both kernel outputs to the host once and read them through typed
  // accessors rather than creating a 0-dim tensor per element.
  // Assumes kernel_label has an integer dtype; a floating dtype would be
  // truncated by the cast to int64 -- TODO confirm.
  const at::Tensor point_vector_cpu = point_vector.to(at::kCPU).to(at::kFloat);
  const at::Tensor label_updated_cpu = label_updated.to(at::kCPU).to(at::kLong);
  const auto region_stats = point_vector_cpu.accessor<float, 2>();
  const auto labels = label_updated_cpu.accessor<int64_t, 2>();

  std::vector<std::vector<float>> pixel_assignment(kernel_region_num);
  for (int64_t region = 0; region < kernel_region_num; ++region) {
    const float numerator = region_stats[region][0];
    const float denominator = region_stats[region][1];
    // Only normalise when the denominator is positive, to avoid dividing by zero.
    pixel_assignment[region].push_back(denominator > 0 ? numerator / denominator : numerator);
    pixel_assignment[region].push_back(denominator);
  }

  // Single row-major pass over the label map. This yields the same per-region
  // ordering as scanning the map once per region, at O(H*W) total cost.
  for (int64_t row = 0; row < height; ++row) {
    for (int64_t col = 0; col < width; ++col) {
      const int64_t region = labels[row][col];
      // Region 0 carries no pixel list; out-of-range labels are ignored.
      if (region <= 0 || region >= kernel_region_num) {
        continue;
      }
      auto& entry = pixel_assignment[region];
      entry.push_back(static_cast<float>(col));
      entry.push_back(static_cast<float>(row));
    }
  }
  return pixel_assignment;
}

// Forward declaration of pixel_group_impl, which shares the signature of
// pixel_group_npu. Its definition is not in this file (presumably it is the
// device-dispatching entry point provided elsewhere -- confirm).
vector<vector<float>> pixel_group_impl(Tensor score, Tensor mask, Tensor embedding,
Tensor kernel_label, Tensor kernel_contour,
int kernel_region_num, float distance_threshold);

// Register pixel_group_npu as the NPU implementation of pixel_group_impl.
REGISTER_NPU_IMPL(pixel_group_impl, pixel_group_npu);
56 changes: 1 addition & 55 deletions mmcv/ops/pixel_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,6 @@
ext_module = ext_loader.load_ext('_ext', ['pixel_group'])


def estimate_confidence(label: torch.Tensor, score: torch.Tensor,
                        label_num: int) -> List[List[float]]:
    """Aggregate per-label score statistics and pixel coordinates.

    Args:
        label (torch.Tensor): Label map; entries greater than 0 are treated
            as region ids and used as row indices into the accumulator.
        score (torch.Tensor): Score map with the same number of elements as
            ``label``.
        label_num (int): Number of label slots in the result.

    Returns:
        List[List[float]]: One list per label index. Element 0 is the mean
        score of the pixels carrying that label (0 when the label has no
        pixels), element 1 is the pixel count, and for labels >= 1 the
        remaining elements are the pixel coordinates flattened as
        (column, row) pairs.
    """
    # Allocate the accumulator directly on the target device.
    point_vector = torch.zeros((label_num, 2),
                               dtype=torch.float32,
                               device=score.device)

    label_flat = label.flatten()
    score_flat = score.flatten()

    foreground = label_flat > 0
    valid_labels = label_flat[foreground]
    # index_add_ requires the source dtype to match the accumulator dtype.
    valid_scores = score_flat[foreground].to(point_vector.dtype)

    # Column 0 accumulates the score sum, column 1 the pixel count.
    point_vector.index_add_(
        0, valid_labels,
        torch.stack((valid_scores, torch.ones_like(valid_scores)), dim=1))

    # Turn sums into means, skipping labels with no pixels.
    populated = point_vector[:, 1] > 0
    point_vector[populated, 0] /= point_vector[populated, 1]

    point_vector_list = point_vector.tolist()
    for region in range(1, label_num):
        coords = (label == region).nonzero(as_tuple=False).float()
        # nonzero yields (row, column); swap to (column, row).
        coords = coords[:, [1, 0]]
        point_vector_list[region].extend(coords.flatten().tolist())

    return point_vector_list


def pixel_group(
score: Union[np.ndarray, Tensor],
mask: Union[np.ndarray, Tensor],
Expand Down Expand Up @@ -89,30 +59,6 @@ def pixel_group(
if isinstance(kernel_contour, np.ndarray):
kernel_contour = torch.from_numpy(kernel_contour)

if score.device.type == 'npu':
import torch_npu
embedding_dim = embedding.shape[2]
kernel_vector = torch.zeros((kernel_region_num, embedding_dim),
dtype=torch.float32).to(score.device)

for label in range(1, kernel_region_num):
label_mask = (kernel_label == label)
label_embeddings = embedding[label_mask]
kernel_vector[label, :] = label_embeddings.sum(dim=0)
vector_sum = label_mask.sum()
kernel_vector[label, :] /= vector_sum

kernel_cv = kernel_vector[label, :]
valid_mask = (mask == 1) & (kernel_label == 0)
valid_embeddings = embedding[valid_mask]
distances = torch.sum((valid_embeddings - kernel_cv)**2, dim=1)
within_threshold = distances < distance_threshold**2

kernel_label[valid_mask] = torch.where(within_threshold, label,
kernel_label[valid_mask])

return estimate_confidence(kernel_label, score, kernel_region_num)

if torch.__version__ == 'parrots':
label = ext_module.pixel_group(
score,
Expand All @@ -137,4 +83,4 @@ def pixel_group(
kernel_label, kernel_contour,
kernel_region_num,
distance_threshold)
return pixel_assignment
return pixel_assignment

0 comments on commit 2ed1460

Please sign in to comment.