From d87565f4879c90d22838c1dd7b7daed1e20f6a8d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E6=A5=9A=E6=9C=AA?=
Date: Tue, 5 Dec 2023 02:45:25 +0000
Subject: [PATCH] =?UTF-8?q?!5881=20[=E8=87=AA=E7=A0=94][=E6=8E=A8=E7=90=86?=
 =?UTF-8?q?=E5=BC=95=E6=93=8E=20AscendIE]SOLOv2=E6=A8=A1=E5=9E=8B=E9=80=82?=
 =?UTF-8?q?=E9=85=8DTorch-AIE=20#1=20*=20[SOLOv2]=20=E9=80=82=E9=85=8D?=
 =?UTF-8?q?=E7=94=A8=E4=BE=8B=20#1=20*=20remove=20solov2=20*=20SOLOv2?=
 =?UTF-8?q?=E6=A8=A1=E5=9E=8BTorch-AIE=E9=80=82=E9=85=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../built-in/cv/segmentation/SOLOV2/README.md | 288 ++++++++++++++++++
 .../cv/segmentation/SOLOV2/SOLOV2.diff        | 280 +++++++++++++++++
 .../cv/segmentation/SOLOV2/requirements.txt   |  11 +
 .../cv/segmentation/SOLOV2/solov2_get_info.py |  75 +++++
 .../segmentation/SOLOV2/solov2_inference.py   |  83 +++++
 .../segmentation/SOLOV2/solov2_postprocess.py | 113 +++++++
 .../segmentation/SOLOV2/solov2_preprocess.py  |  84 +++++
 .../SOLOV2/solov2_pth2torchscript.py          |  38 +++
 8 files changed, 972 insertions(+)
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/README.md
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/SOLOV2.diff
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/requirements.txt
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_get_info.py
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_inference.py
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_postprocess.py
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_preprocess.py
 create mode 100644 AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_pth2torchscript.py

diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/README.md b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/README.md
new file mode 100644
index 000000000..e57a48186
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/README.md
@@ -0,0 +1,288 @@
+# SOLOV2模型-基于推理引擎PyTorch框架插件的部署及推理指导
+
+- [概述](#ZH-CN_TOPIC_0000001172161501)
+
+- [推理环境准备](#ZH-CN_TOPIC_0000001126281702)
+
+- [快速上手](#ZH-CN_TOPIC_0000001126281700)
+
+  - [获取源码](#section4622531142816)
+  - [准备数据集](#section183221994411)
+  - [模型推理](#section741711594517)
+
+- [模型推理性能](#ZH-CN_TOPIC_0000001172201573)
+
+- [配套环境](#ZH-CN_TOPIC_0000001126121892)
+
+  ******
+
+
+# 概述
+
+SOLOV2模型是一个box-free的实例分割模型。SOLOV2相对SOLOV1的主要改动有两点:一是采用一种高效的整体实例掩码表示方案,动态地分割图像中的每个实例,而无需依赖边界框检测。具体来说,目标掩码的生成(Mask generation)被分解为掩码核预测(Mask kernel prediction)和掩码特征学习(Mask feature learning),分别负责生成卷积核和待卷积的特征映射。二是SOLOV2通过新的矩阵非极大值抑制(Matrix NMS)技术显著降低了推理开销。
+
+
+- 参考实现:
+
+  ```
+  url=https://github.com/WXinlong/SOLO
+  branch=master
+  commit_id=95f3732d5fbb0d7c7044c7dd074f439d48a72ce5
+  ```
+
+## 输入输出数据
+
+- 输入数据
+
+  | 输入数据 | 数据类型 | 大小                       | 数据排布格式 |
+  | -------- | -------- | -------------------------- | ------------ |
+  | input    | RGB_FP32 | batchsize x 3 x 800 x 1216 | NCHW         |
+
+
+- 输出数据
+
+  | 输出数据 | 大小            | 数据类型 | 数据排布格式 |
+  | -------- | --------------- | -------- | ------------ |
+  | output1  | 100 x 200 x 304 | FLOAT32  | ND           |
+  | output2  | 100             | INT32    | ND           |
+  | output3  | 100             | FLOAT32  | ND           |
+
+
+# 推理环境准备\[所有版本\]
+
+- 该模型需要以下插件与驱动
+
+  **表 1** 版本配套表
+
+| 配套                                                            | 版本             | 环境准备指导                                                                                          |
+| --------------------------------------------------------------- | ---------------- | ----------------------------------------------------------------------------------------------------- |
+| 固件与驱动                                                      | 23.0.rc1         | [Pytorch框架推理环境准备](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+| CANN                                                            | 7.0.RC1.alpha003 | -                                                                                                       |
+| Python                                                          | 3.9.11           | -                                                                                                       |
+| PyTorch                                                         | 2.0.1            | -                                                                                                       |
+| torch_aie                                                       | 6.3.rc2          | -                                                                                                       |
+| 说明:Atlas 300I Duo 推理卡请以CANN版本选择实际固件与驱动版本。 | \                | \                                                                                                       |
+
+# 快速上手
+
+## 获取源码
+
+1. 获取开源代码仓。
+
+   在已下载的源码包根目录下,执行如下命令。
+
+   ```
+   git clone https://github.com/WXinlong/SOLO.git -b master
+   cd SOLO
+   git reset --hard 95f3732d5fbb0d7c7044c7dd074f439d48a72ce5
+   cd ..
+   ```
+
+2. 安装依赖。
+
+   ```
+   apt-get install libjpeg-dev zlib1g-dev
+   pip install -r requirements.txt
+   ```
+
+   其中mmcv安装建议参考[官方安装指导说明](https://mmcv.readthedocs.io/zh_CN/latest/get_started/installation.html#pip)。
+   ```
+   # Linux cpu torch2.0.x mmcv2.1.0 (Linux环境安装指令)
+   pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cpu/torch2.0/index.html
+   ```
+
+   其中mmdet需要用以下方式安装。
+
+   ```
+   cd SOLO
+   patch -p1 < ../MMDET.diff
+   patch -p1 < ../SOLOV2.diff
+   pip install -v -e .
+   cd ..
+   ```
+
+## 准备数据集
+
+1. 获取原始数据集。(解压命令参考 tar -xvf \*.tar 与 unzip \*.zip)
+
+   本模型需要coco2017数据集,数据集下载地址:https://cocodataset.org/
+
+   请将val2017图片及其标注文件放入服务器/data/datasets/coco/文件夹:val2017目录存放coco数据集的验证集图片,annotations目录存放coco数据集的instances_val2017.json,文件目录结构如下:
+
+   ```
+   ├──data
+   │  └──datasets
+   │     └──coco
+   │        ├──annotations
+   │        │  └──instances_val2017.json
+   │        └──val2017
+   ```
+
+2. 数据预处理。
+
+   ```
+   python3 solov2_preprocess.py \
+       --image_src_path=/data/datasets/coco/val2017 \
+       --bin_file_path=val2017_bin \
+       --meta_file_path=val2017_bin_meta \
+       --model_input_height=800 \
+       --model_input_width=1216
+   ```
+
+   - --image_src_path:数据集路径
+   - --bin_file_path:生成的图片bin文件路径
+   - --meta_file_path:生成的图片附加信息路径(临时信息,solov2_get_info.py需要用到)
+
+   每个图像对应生成一个二进制bin文件和一个附加信息文件。
+
+3. 生成数据集info文件。
+
+   执行“solov2_get_info.py”,会生成“solov2.info”与“solov2_meta.info”,其中“solov2_meta.info”用于后处理。
+
+   ```
+   python3 solov2_get_info.py /data/datasets/coco/ SOLO/configs/solov2/solov2_r50_fpn_8gpu_1x.py val2017_bin val2017_bin_meta solov2.info solov2_meta.info 1216 800
+   ```
+
+   * “/data/datasets/coco/”:数据集路径。
+
+   * “SOLO/configs/solov2/solov2_r50_fpn_8gpu_1x.py”:模型配置文件。
+
+   * “val2017_bin”:预处理后的图片bin文件的**相对路径**。
+
+   * “val2017_bin_meta”:预处理后的图片附加信息文件的**相对路径**。
+
+   * “solov2.info”:生成的数据集图片信息文件的保存路径。
+
+   * “solov2_meta.info”:生成的数据集附加信息文件的保存路径。
+
+   * “1216”:图片宽。
+
+   * “800”:图片高。
+
+   运行成功后,在当前目录中生成“solov2.info”与“solov2_meta.info”。
+
+
+## 模型推理
+
+### 1. 模型转换
+
+   使用PyTorch将模型权重文件.pth转换为torchscript文件。
+
+   1. 获取权重文件。
+
+      权重文件:[SOLOv2_R50_1x.pth](https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/SOLOV2/PTH/SOLOv2_R50_1x.pth),请将其放在与“solov2_pth2torchscript.py”文件同一目录下。
+
+   2. 导出torchscript文件。
+
+      ```shell
+      python3 solov2_pth2torchscript.py \
+          --config SOLO/configs/solov2/solov2_r50_fpn_8gpu_1x.py \
+          --pth-path SOLOv2_R50_1x.pth \
+          --shape 800 1216
+      ```
+
+      获得solov2_torchscript.pt文件。
+
+      + 参数说明
+        + `--config`:模型配置文件路径
+        + `--pth-path`:PTH权重文件路径
+        + `--shape`:模型输入shape
+
+   3. 配置环境变量。
+
+      ```shell
+      source /usr/local/Ascend/ascend-toolkit/set_env.sh
+      export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64/driver/:${LD_LIBRARY_PATH}
+      ```
+
+      > **说明:**
+      > 该脚本中环境变量仅供参考,请以实际安装环境配置环境变量。详细介绍请参见《[CANN 开发辅助工具指南 \(推理\)](https://support.huawei.com/enterprise/zh/ascend-computing/cann-pid-251168373?category=developer-documents&subcategory=auxiliary-development-tools)》。
+
+   4. 
执行命令查看芯片名称($\{chip\_name\})。 + + ```shell + npu-smi info + #该设备芯片名为Ascend310P3 (在下一步中赋值给soc_version环境变量) + 回显如下: + +-------------------+-----------------+------------------------------------------------------+ + | NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page) | + | Chip Device | Bus-Id | AICore(%) Memory-Usage(MB) | + +===================+=================+======================================================+ + | 0 310P3 | OK | 15.8 42 0 / 0 | + | 0 0 | 0000:82:00.0 | 0 1074 / 21534 | + +===================+=================+======================================================+ + | 1 310P3 | OK | 15.4 43 0 / 0 | + | 0 1 | 0000:89:00.0 | 0 1070 / 21534 | + +===================+=================+======================================================+ + ``` + + 5. 对原生ts文件执行torch_aie编译,导出NPU支持的ts文件 + + ```shell + soc_version="Ascend310P3" # User-defined + python solov2_export_torch_aie_ts.py \ + --torch-script-path ./solov2_torchscript.pt \ + --batch-size 1 \ + --save-path ./ \ + --soc-version ${soc_version} + ``` + + + 参数说明 + + `--torch-script-path`:原生ts文件路径 + + `--batch-size`:用户自定义的batch size + + `--save-path`:AIE编译后的ts文件保存路径 + + `--soc-version`:NPU型号 + + 运行成功后生成solov2_torchscriptb1_torch_aie.pt模型文件。 + +### 2. 执行推理并验证精度与性能 + + 1. 执行推理 + + 推理完成后将输出模型推理性能结果 + + ```shell + python solov2_inference.py \ + --aie-module-path ./solov2_torchscriptb1_torch_aie.pt \ + --batch-size 1 \ + --processed-dataset-path ./val2017_bin/ \ + --output-save-path ./result_aie/ \ + --device-id 0 + ``` + + + 参数说明: + + --aie-module-path: AIE编译后模型的路径 + + --batch-size: 模型输入的BatchSize + + --processed-dataset-path:经预处理COCO数据集的路径 + + --output-save-path:推理结果保存路径 + + --device-id: Ascend NPU ID(可通过npu-smi info查看) + + 2. 数据后处理 + + 处理完成后将输出模型推理精度结果 + + ```shell + python solov2_postprocess.py \ + --dataset_path /data/datasets/coco/ \ + --model_config SOLO/configs/solov2/solov2_r50_fpn_8gpu_1x.py \ + --bin_data_path result_aie \ + --meta_info solov2_meta.info \ + --net_out_num 3 \ + --model_input_height 800 \ + --model_input_width 1216 + ``` + +# 模型推理性能&精度 + +基于推理引擎完成推理计算,精度与性能可参考下列数据: + +| Soc version | Batch Size | Dataset | Accuracy | Performance | +| ---------- | ---------- | ---------- | ---------- | ---------- | +| Ascend910A | 1 | coco2017 | Average Precision(IoU=0.50:0.95): 0.340 | 7.21 fps | + + +# FAQ +1. 若遇到类似报错:ImportError: /lib/aarch64-linux-gnu/libGLdispatch.so.0: cannot allocate memory in static TLS block + + 解决方法: + export LD_PRELOAD=$LD_PRELOAD:{报错信息中的路径} \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/SOLOV2.diff b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/SOLOV2.diff new file mode 100644 index 000000000..33308bcc5 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/SOLOV2.diff @@ -0,0 +1,280 @@ +diff --git a/mmdet/core/post_processing/matrix_nms.py b/mmdet/core/post_processing/matrix_nms.py +index cbbe420..764d9cb 100644 +--- a/mmdet/core/post_processing/matrix_nms.py ++++ b/mmdet/core/post_processing/matrix_nms.py +@@ -1,6 +1,17 @@ + import torch + + ++def triu_(x, diagonal=0): ++ t = x.shape[0] ++ base = torch.arange(t, device=x.device) ++ mask = base.expand(t, t) ++ base = base.unsqueeze(-1) ++ if diagonal: ++ base = base + diagonal ++ mask = mask >= base ++ return mask * x ++ ++ + def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0, sum_masks=None): + """Matrix NMS for multi-class masks. 
+ +@@ -26,10 +37,12 @@ def matrix_nms(seg_masks, cate_labels, cate_scores, kernel='gaussian', sigma=2.0 + # union. + sum_masks_x = sum_masks.expand(n_samples, n_samples) + # iou. +- iou_matrix = (inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix)).triu(diagonal=1) ++ iou_matrix = inter_matrix / (sum_masks_x + sum_masks_x.transpose(1, 0) - inter_matrix) ++ iou_matrix = triu_(iou_matrix, diagonal=1) + # label_specific matrix. + cate_labels_x = cate_labels.expand(n_samples, n_samples) +- label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float().triu(diagonal=1) ++ label_matrix = (cate_labels_x == cate_labels_x.transpose(1, 0)).float() ++ label_matrix = triu_(label_matrix, diagonal=1) + + # IoU compensation + compensate_iou, _ = (iou_matrix * label_matrix).max(0) +diff --git a/mmdet/models/anchor_heads/solov2_head.py b/mmdet/models/anchor_heads/solov2_head.py +index 2765eb2..5142cbd 100644 +--- a/mmdet/models/anchor_heads/solov2_head.py ++++ b/mmdet/models/anchor_heads/solov2_head.py +@@ -26,8 +26,8 @@ def points_nms(heat, kernel=2): + # kernel must be 2 + hmax = nn.functional.max_pool2d( + heat, (kernel, kernel), stride=1, padding=1) +- keep = (hmax[:, :, :-1, :-1] == heat).float() +- return heat * keep ++ keep = torch.abs(hmax[:, :, :-1, :-1] - heat) < 1e-3 ++ return keep.int() + + def dice_loss(input, target): + input = input.contiguous().view(input.size()[0], -1) +@@ -150,8 +150,13 @@ class SOLOv2Head(nn.Module): + ins_kernel_feat = x + # ins branch + # concat coord +- x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device) +- y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device) ++ feat_h, feat_w = ins_kernel_feat.shape[-2], ins_kernel_feat.shape[-1] ++ feat_h, feat_w = int(feat_h.cpu().numpy() if isinstance(feat_h, torch.Tensor) else feat_h), \ ++ int(feat_w.cpu().numpy() if isinstance(feat_w, torch.Tensor) else feat_w) ++ step_x = 2. / (feat_w - 1) ++ step_y = 2. 
/ (feat_h - 1) ++ x_range = torch.arange(-1, 1.00147, step_x, device=ins_kernel_feat.device) ++ y_range = torch.arange(-1, 1.00147, step_y, device=ins_kernel_feat.device) + y, x = torch.meshgrid(y_range, x_range) + y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1]) + x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1]) +@@ -177,7 +182,9 @@ class SOLOv2Head(nn.Module): + cate_pred = self.solo_cate(cate_feat) + + if eval: +- cate_pred = points_nms(cate_pred.sigmoid(), kernel=2).permute(0, 2, 3, 1) ++ cate_mark = points_nms(cate_pred, kernel=2) ++ cate_pred = cate_pred.sigmoid() ++ cate_pred = (cate_pred * cate_mark).permute(0, 2, 3, 1) + return cate_pred, kernel_pred + + def loss(self, +@@ -355,12 +362,13 @@ class SOLOv2Head(nn.Module): + grid_order_list.append(grid_order) + return ins_label_list, cate_label_list, ins_ind_label_list, grid_order_list + +- def get_seg(self, cate_preds, kernel_preds, seg_pred, img_metas, cfg, rescale=None): ++ def get_seg(self, cate_preds, kernel_preds, seg_pred, cfg, rescale=None): + num_levels = len(cate_preds) + featmap_size = seg_pred.size()[-2:] ++ img_num = 1 + + result_list = [] +- for img_id in range(len(img_metas)): ++ for img_id in range(img_num): + cate_pred_list = [ + cate_preds[i][img_id].view(-1, self.cate_out_channels).detach() for i in range(num_levels) + ] +@@ -369,15 +377,11 @@ class SOLOv2Head(nn.Module): + kernel_preds[i][img_id].permute(1, 2, 0).view(-1, self.kernel_out_channels).detach() + for i in range(num_levels) + ] +- img_shape = img_metas[img_id]['img_shape'] +- scale_factor = img_metas[img_id]['scale_factor'] +- ori_shape = img_metas[img_id]['ori_shape'] + + cate_pred_list = torch.cat(cate_pred_list, dim=0) + kernel_pred_list = torch.cat(kernel_pred_list, dim=0) + +- result = self.get_seg_single(cate_pred_list, seg_pred_list, kernel_pred_list, +- featmap_size, img_shape, ori_shape, scale_factor, cfg, rescale) ++ result = self.get_seg_single(cate_pred_list, seg_pred_list, kernel_pred_list, featmap_size, cfg, rescale) + result_list.append(result) + return result_list + +@@ -386,28 +390,17 @@ class SOLOv2Head(nn.Module): + seg_preds, + kernel_preds, + featmap_size, +- img_shape, +- ori_shape, +- scale_factor, + cfg, + rescale=False, debug=False): + + assert len(cate_preds) == len(kernel_preds) + +- # overall info. +- h, w, _ = img_shape +- upsampled_size_out = (featmap_size[0] * 4, featmap_size[1] * 4) +- + # process. +- inds = (cate_preds > cfg.score_thr) +- cate_scores = cate_preds[inds] +- if len(cate_scores) == 0: +- return None ++ cate_scores, cate_preds = torch.max(cate_preds, dim=-1) ++ cate_scores, inds = torch.topk(cate_scores, k=200) + +- # cate_labels & kernel_preds +- inds = inds.nonzero() +- cate_labels = inds[:, 1] +- kernel_preds = kernel_preds[inds[:, 0]] ++ cate_labels = cate_preds[inds].int() ++ kernel_preds = kernel_preds[inds] + + # trans vector. + size_trans = cate_labels.new_tensor(self.seg_num_grids).pow(2).cumsum(0) +@@ -417,33 +410,37 @@ class SOLOv2Head(nn.Module): + strides[:size_trans[0]] *= self.strides[0] + for ind_ in range(1, n_stage): + strides[size_trans[ind_-1]:size_trans[ind_]] *= self.strides[ind_] +- strides = strides[inds[:, 0]] ++ strides = strides[inds] + + # mask encoding. 
+ I, N = kernel_preds.shape + kernel_preds = kernel_preds.view(I, N, 1, 1) +- seg_preds = F.conv2d(seg_preds, kernel_preds, stride=1).squeeze(0).sigmoid() ++ seg_preds_shape = seg_preds.shape ++ seg_preds = seg_preds.view(seg_preds_shape[0], seg_preds_shape[1], -1) ++ kernel_preds = kernel_preds.view(kernel_preds.shape[0], kernel_preds.shape[1]) ++ seg_preds = torch.matmul(kernel_preds, seg_preds) ++ seg_preds = seg_preds.view(seg_preds_shape[0], I, seg_preds_shape[2], seg_preds_shape[3]).squeeze(0).sigmoid() + # mask. + seg_masks = seg_preds > cfg.mask_thr +- sum_masks = seg_masks.sum((1, 2)).float() ++ sum_masks = seg_masks.int().sum((1, 2)).float() + + # filter. + keep = sum_masks > strides +- if keep.sum() == 0: +- return None +- +- seg_masks = seg_masks[keep, ...] +- seg_preds = seg_preds[keep, ...] +- sum_masks = sum_masks[keep] +- cate_scores = cate_scores[keep] +- cate_labels = cate_labels[keep] ++ keep_int = keep.int() ++ keep_mask = keep_int.reshape(-1, 1, 1) ++ keep_mask = keep_mask.expand(-1, seg_masks.shape[1], seg_masks.shape[2]).int() ++ seg_masks = torch.mul(seg_masks, keep_mask) ++ seg_preds = torch.mul(seg_preds, keep_mask) ++ cate_scores = torch.mul(cate_scores, keep_int) ++ sum_masks = torch.mul(sum_masks, keep_int) ++ sum_masks += 0.1 + + # maskness. + seg_scores = (seg_preds * seg_masks.float()).sum((1, 2)) / sum_masks + cate_scores *= seg_scores + + # sort and keep top nms_pre +- sort_inds = torch.argsort(cate_scores, descending=True) ++ _, sort_inds = torch.sort(cate_scores, descending=True) + if len(sort_inds) > cfg.nms_pre: + sort_inds = sort_inds[:cfg.nms_pre] + seg_masks = seg_masks[sort_inds, :, :] +@@ -456,27 +453,12 @@ class SOLOv2Head(nn.Module): + cate_scores = matrix_nms(seg_masks, cate_labels, cate_scores, + kernel=cfg.kernel,sigma=cfg.sigma, sum_masks=sum_masks) + +- # filter. 
+- keep = cate_scores >= cfg.update_thr +- if keep.sum() == 0: +- return None +- seg_preds = seg_preds[keep, :, :] +- cate_scores = cate_scores[keep] +- cate_labels = cate_labels[keep] +- + # sort and keep top_k +- sort_inds = torch.argsort(cate_scores, descending=True) ++ _, sort_inds = torch.sort(cate_scores, descending=True) + if len(sort_inds) > cfg.max_per_img: + sort_inds = sort_inds[:cfg.max_per_img] + seg_preds = seg_preds[sort_inds, :, :] + cate_scores = cate_scores[sort_inds] + cate_labels = cate_labels[sort_inds] + +- seg_preds = F.interpolate(seg_preds.unsqueeze(0), +- size=upsampled_size_out, +- mode='bilinear')[:, :, :h, :w] +- seg_masks = F.interpolate(seg_preds, +- size=ori_shape[:2], +- mode='bilinear').squeeze(0) +- seg_masks = seg_masks > cfg.mask_thr +- return seg_masks, cate_labels, cate_scores ++ return seg_preds, cate_labels, cate_scores +diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py +index 82f91bd..4a93a27 100644 +--- a/mmdet/models/detectors/base.py ++++ b/mmdet/models/detectors/base.py +@@ -124,7 +124,7 @@ class BaseDetector(nn.Module, metaclass=ABCMeta): + assert imgs_per_gpu == 1 + + if num_augs == 1: +- return self.simple_test(imgs[0], img_metas[0], **kwargs) ++ return self.simple_test(imgs[0], **kwargs) + else: + return self.aug_test(imgs, img_metas, **kwargs) + +diff --git a/mmdet/models/detectors/single_stage_ins.py b/mmdet/models/detectors/single_stage_ins.py +index 773d5d2..aa12e7e 100644 +--- a/mmdet/models/detectors/single_stage_ins.py ++++ b/mmdet/models/detectors/single_stage_ins.py +@@ -78,7 +78,7 @@ class SingleStageInsDetector(BaseDetector): + *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore) + return losses + +- def simple_test(self, img, img_meta, rescale=False): ++ def simple_test(self, img, img_meta=None, rescale=False): + x = self.extract_feat(img) + outs = self.bbox_head(x, eval=True) + +@@ -86,7 +86,7 @@ class SingleStageInsDetector(BaseDetector): + mask_feat_pred = self.mask_feat_head( + x[self.mask_feat_head. + start_level:self.mask_feat_head.end_level + 1]) +- seg_inputs = outs + (mask_feat_pred, img_meta, self.test_cfg, rescale) ++ seg_inputs = outs + (mask_feat_pred, self.test_cfg, rescale) + else: + seg_inputs = outs + (img_meta, self.test_cfg, rescale) + seg_result = self.bbox_head.get_seg(*seg_inputs) +diff --git a/mmdet/models/mask_heads/mask_feat_head.py b/mmdet/models/mask_heads/mask_feat_head.py +index 980b4ad..2e8504e 100644 +--- a/mmdet/models/mask_heads/mask_feat_head.py ++++ b/mmdet/models/mask_heads/mask_feat_head.py +@@ -105,8 +105,13 @@ class MaskFeatHead(nn.Module): + input_p = inputs[i] + if i == 3: + input_feat = input_p +- x_range = torch.linspace(-1, 1, input_feat.shape[-1], device=input_feat.device) +- y_range = torch.linspace(-1, 1, input_feat.shape[-2], device=input_feat.device) ++ feat_h, feat_w = input_feat.shape[-2], input_feat.shape[-1] ++ feat_h, feat_w = int(feat_h.cpu().numpy() if isinstance(feat_h, torch.Tensor) else feat_h), \ ++ int(feat_w.cpu().numpy() if isinstance(feat_w, torch.Tensor) else feat_w) ++ step_x = 2. / (feat_w - 1) ++ step_y = 2. 
/ (feat_h - 1) ++ x_range = torch.arange(-1, 1.00147, step_x, device=input_feat.device) ++ y_range = torch.arange(-1, 1.00147, step_y, device=input_feat.device) + y, x = torch.meshgrid(y_range, x_range) + y = y.expand([input_feat.shape[0], 1, -1, -1]) + x = x.expand([input_feat.shape[0], 1, -1, -1]) diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/requirements.txt b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/requirements.txt new file mode 100644 index 000000000..772b1a39d --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/requirements.txt @@ -0,0 +1,11 @@ +torch == 2.0.1 +torchvision == 0.15.2 +numpy == 1.23.5 +Cython == 0.29.33 +Opencv-python == 4.8.1.78 +pycocotools == 2.0.7 +Pytest-runner == 5.3.1 +protobuf == 3.20.2 +decorator == 5.1.1 +sympy == 1.12 +tqdm == 4.66.1 \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_get_info.py b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_get_info.py new file mode 100644 index 000000000..8778f3cd3 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_get_info.py @@ -0,0 +1,75 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import mmcv +import pickle as pk +from mmdet.datasets import build_dataset + + +def load_config(config_path, image_src_path, ann_file, img_prefix): + cfg = mmcv.Config.fromfile(config_path) + cfg.data.test.ann_file = image_src_path + ann_file + cfg.data.test.img_prefix = image_src_path + img_prefix + return cfg + + +def write_info(dataset, bin_path, width, height, info_name): + with open(info_name, "w") as fp1: + for idx in range(5000): + img_id = dataset.img_ids[idx] + fp1.write("{} {}/{:0>12d}.bin {} {}\n".format(idx, bin_path, img_id, width, height)) + + +def write_meta_info(dataset, meta_path, info_meta_name): + with open(info_meta_name, "w") as fp2: + for idx in range(5000): + img_id = dataset.img_ids[idx] + with open("%s/%012d.pk" % (meta_path, img_id), "rb") as fp_meta: + meta = pk.load(fp_meta) + fp2.write("{} {}/{:0>12d}.bin {} {} {} {}\n".format( + idx, + meta_path, + img_id, + meta['img_shape'][1], + meta['img_shape'][0], + meta['ori_shape'][1], + meta['ori_shape'][0] + )) + + +def main(): + image_src_path = sys.argv[1] + config_path = sys.argv[2] + bin_path = sys.argv[3] + meta_path = sys.argv[4] + info_name = sys.argv[5] + info_meta_name = sys.argv[6] + width = int(sys.argv[7]) + height = int(sys.argv[8]) + + ann_file = '/annotations/instances_val2017.json' + img_prefix = '/val2017/' + + cfg = load_config(config_path, image_src_path, ann_file, img_prefix) + dataset = build_dataset(cfg.data.test) + + write_info(dataset, bin_path, width, height, info_name) + write_meta_info(dataset, meta_path, info_meta_name) + + print("Get info done!") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_inference.py 
b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_inference.py new file mode 100644 index 000000000..ab8744fa2 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_inference.py @@ -0,0 +1,83 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import time +import os + +import torch +import torch_aie +import numpy as np +from tqdm import tqdm + + +def parse_arguments(): + parser = argparse.ArgumentParser(description="inference") + parser.add_argument("--aie-module-path", default="./solov2_torchscriptb1_torch_aie.pt") + parser.add_argument("--batch-size", default=1) + parser.add_argument("--processed-dataset-path", default="./val2017_bin/") + parser.add_argument("--output-save-path", default="./result_aie/") + parser.add_argument("--model-input-height", type=int, default=800, help="input tensor height") + parser.add_argument("--model-input-width", type=int, default=1216, help="input tensor width") + parser.add_argument("--device-id", type=int, default=0, help="device id") + return parser.parse_args() + + +def load_aie_module(args): + torch_aie.set_device(args.device_id) + aie_module = torch.jit.load(args.aie_module_path) + aie_module.eval() + return aie_module + + +def main(): + # Parse user input arguments + args = parse_arguments() + if not os.path.exists(args.output_save_path): + os.makedirs(args.output_save_path) + + # Load AIE module + aie_module = load_aie_module(args) + + # Start inference + inference_time = [] + stream = torch_aie.npu.Stream(f"npu:{args.device_id}") + for idx, filename in enumerate(tqdm(os.listdir(args.processed_dataset_path))): + file_name = os.path.splitext(filename)[0] + input_tensor = torch.from_numpy(np.fromfile(args.processed_dataset_path + filename, dtype="float32")).view(1, 3, args.model_input_height, args.model_input_width) + input_tensor = input_tensor.to(f"npu:{args.device_id}") + with torch_aie.npu.stream(stream): + start_time = time.time() + aie_result = aie_module(input_tensor) + stream.synchronize() + cost = time.time() - start_time + # Warm-up using 5 steps + if idx >= 5: + inference_time.append(cost) + for i, tensor in enumerate(aie_result): + tensor = tensor.to("cpu") + tensor.numpy().tofile(f'{args.output_save_path}{file_name}_{i}.bin') + + # Calculate inference avg cost and throughput + aie_avg_cost = sum(inference_time) / len(inference_time) * 1000 + aie_throughput = args.batch_size / (sum(inference_time) / len(inference_time)) + + print(f'\n[INFO] Torch-AIE inference avg cost (batch={args.batch_size}): {aie_avg_cost} ms/pic') + print(f'[INFO] Throughput = {aie_throughput} pic/s') + print('[INFO] SOLOv2 Torch-AIE inference process finished') + + +if __name__ == "__main__": + print("[INFO] SOLOV2 Torch-AIE inference process start") + main() diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_postprocess.py b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_postprocess.py new file mode 100644 index 
000000000..b4a6c3b56 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_postprocess.py @@ -0,0 +1,113 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mmcv +import numpy as np +import argparse +import torch +import torch.nn.functional as F +import pycocotools.mask as mask_util +from mmdet.core import coco_eval, results2json, results2json_segm +from mmdet.datasets import build_dataset +import os.path as osp +from tqdm import tqdm + + +ann_file = 'annotations/instances_val2017.json' +img_prefix = 'val2017/' + + +def get_masks(result, num_classes=80): + for cur_result in result: + masks = [[] for _ in range(num_classes)] + if cur_result is None: + return masks + seg_pred = cur_result[0].astype(np.uint8) + cate_label = cur_result[1].astype(np.int) + cate_score = cur_result[2].astype(np.float) + num_ins = seg_pred.shape[0] + for idx in range(num_ins): + cur_mask = seg_pred[idx, ...] + rle = mask_util.encode( + np.array(cur_mask[:, :, np.newaxis], order='F'))[0] + rst = (rle, cate_score[idx]) + masks[cate_label[idx]].append(rst) + return masks + + +def handle_seg(seg, img_shape, ori_shape, input_shape=(800, 1216), mask_thr=0.5): + seg = torch.tensor(seg) + h, w, = img_shape + pad_left = (input_shape[1] - w) // 2 + pad_top = (input_shape[0] - h) // 2 + seg = F.interpolate(seg.unsqueeze(0), + size=input_shape, + mode='bilinear')[:, :, pad_top:pad_top + h, pad_left:pad_left + w] + + seg = F.interpolate(seg, + size=ori_shape[:2], + mode='bilinear').squeeze(0) + seg = seg > mask_thr + return seg.numpy() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--dataset_path') + parser.add_argument('--model_config') + parser.add_argument('--bin_data_path') + parser.add_argument('--meta_info') + parser.add_argument('--net_out_num', type=int) + parser.add_argument("--model_input_height", type=int, + help='input tensor height') + parser.add_argument("--model_input_width", type=int, + help='input tensor width') + + args = parser.parse_args() + + cfg = mmcv.Config.fromfile(args.model_config) + cfg.data.test.test_mode = True + cfg.data.test.ann_file = args.dataset_path + ann_file + cfg.data.test.img_prefix = args.dataset_path + img_prefix + dataset = build_dataset(cfg.data.test) + num_classes = len(dataset.CLASSES) + + results = [] + + with open(args.meta_info, "r") as fp: + for line in tqdm(fp): + _, file_path, img_w, img_h, ori_w, ori_h = line.split() + img_w = int(img_w) + img_h = int(img_h) + ori_w = int(ori_w) + ori_h = int(ori_h) + file_name = file_path.split("/")[1].replace(".bin", "") + file_name = osp.join(args.bin_data_path, file_name) + result = [] + for idx in range(args.net_out_num): + if idx == 1: + result.append(np.fromfile( + f"{file_name}_{idx}.bin", dtype=np.int32)) + else: + result.append(np.fromfile( + f"{file_name}_{idx}.bin", dtype=np.float32)) + result[0].shape = (100, args.model_input_height // + 4, args.model_input_width // 4) + result[0] = handle_seg(result[0], 
(img_h, img_w), (ori_h, ori_w), + (args.model_input_height, args.model_input_width)) + result = get_masks([result], num_classes) + results.append(result) + + result_files = results2json_segm(dataset, results, "results_solo.pkl") + coco_eval(result_files, ["segm"], dataset.coco) diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_preprocess.py b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_preprocess.py new file mode 100644 index 000000000..f37714a58 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_preprocess.py @@ -0,0 +1,84 @@ +# Copyright 2023 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import numpy as np +import cv2 +import mmcv +import torch +import pickle as pk +import multiprocessing +from tqdm import tqdm + +flags = None + +def resize(img, size): + old_h = img.shape[0] + old_w = img.shape[1] + scale_ratio = min(size[0] / old_w, size[1] / old_h) + new_w = int(np.floor(old_w * scale_ratio)) + new_h = int(np.floor(old_h * scale_ratio)) + resized_img = mmcv.imresize(img, (new_w, new_h)) + return resized_img, scale_ratio + + +def gen_input_bin(file_batches, batch): + for file in file_batches[batch]: + + image = mmcv.imread(os.path.join(flags.image_src_path, file)) + ori_shape = image.shape + image, scale_factor = resize(image, (flags.model_input_width, flags.model_input_height)) + img_shape = image.shape + mean = np.array([123.675, 116.28, 103.53], dtype=np.float32) + std = np.array([58.395, 57.12, 57.375], dtype=np.float32) + image = mmcv.imnormalize(image, mean, std) + h = image.shape[0] + w = image.shape[1] + pad_left = (flags.model_input_width - w) // 2 + pad_top = (flags.model_input_height - h) // 2 + pad_right = flags.model_input_width - pad_left - w + pad_bottom = flags.model_input_height - pad_top - h + image = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0) + image = image.transpose(2, 0, 1) + image.tofile(os.path.join(flags.bin_file_path, file.split('.')[0] + ".bin")) + image_meta = {'img_shape': img_shape, 'scale_factor': scale_factor, 'ori_shape': ori_shape} + with open(os.path.join(flags.meta_file_path, file.split('.')[0] + ".pk"), "wb") as fp: + pk.dump(image_meta, fp) + + +def preprocess(): + files = os.listdir(flags.image_src_path) + file_batches = [files[i:i + 100] for i in range(0, 5000, 100) if files[i:i + 100] != []] + thread_pool = multiprocessing.Pool(len(file_batches)) + for batch in tqdm(range(len(file_batches))): + thread_pool.apply_async(gen_input_bin, args=(file_batches, batch)) + thread_pool.close() + thread_pool.join() + print("in thread, except will not report! 
please ensure bin files generated.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='preprocess of SOLOV2 model') + parser.add_argument("--image_src_path", default="/root/datasets/coco/val2017", help='image of dataset') + parser.add_argument("--bin_file_path", default="val2017_bin", help='Preprocessed image buffer') + parser.add_argument("--meta_file_path", default="val2017_bin_meta", help='Get image meta') + parser.add_argument("--model_input_height", default=800, type=int, help='input tensor height') + parser.add_argument("--model_input_width", default=1216, type=int, help='input tensor width') + flags = parser.parse_args() + if not os.path.exists(flags.bin_file_path): + os.makedirs(flags.bin_file_path) + if not os.path.exists(flags.meta_file_path): + os.makedirs(flags.meta_file_path) + preprocess() diff --git a/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_pth2torchscript.py b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_pth2torchscript.py new file mode 100644 index 000000000..49786da59 --- /dev/null +++ b/AscendIE/TorchAIE/built-in/cv/segmentation/SOLOV2/solov2_pth2torchscript.py @@ -0,0 +1,38 @@ +# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import argparse +import numpy as np +from mmdet.apis import init_detector + +input_names = ['input'] +output_names = ['seg_preds', 'cate_labels', 'cate_scores'] + +def pth2torchscript(args, fake_input): + model = init_detector(args.config, args.pth_path, device='cpu') + model.forward = model.simple_test + ts_model = torch.jit.trace(model, fake_input) + ts_model.save(args.out_path) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', help='model config') + parser.add_argument('--out-path', default='./solov2_torchscript.pt', help='onnx output name') + parser.add_argument('--pth-path', help='model pth path') + parser.add_argument('--shape', type=int, nargs='+', help='input image size hxw') + args = parser.parse_args() + assert len(args.shape) == 2 + fake_input = torch.randn(1, 3, args.shape[0], args.shape[1]) + pth2torchscript(args, fake_input)