Skip to content

Commit

Permalink
!5851 [自研][推理引擎 AscendIE]CenterNet模型适配Torch-AIE #1
Browse files Browse the repository at this point in the history
* [CenterNet] 适配用例 #1
  • Loading branch information
陈楚未 authored and 杨博 committed Dec 1, 2023
1 parent 27c42b5 commit a1f876b
Show file tree
Hide file tree
Showing 11 changed files with 934 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os
import argparse
import torch
import torch_aie
from torch_aie import _enums
import torchvision

def export_torch_aie(model_path, batch_size, soc_version, save_path, device_id):
    """Compile a traced TorchScript CenterNet model with Torch-AIE and save it.

    Args:
        model_path: Path of the TorchScript file loaded with ``torch.jit.load``.
        batch_size: Batch dimension of the static ``(batch_size, 3, 512, 512)``
            input description handed to the compiler.
        soc_version: Forwarded unchanged to ``torch_aie.compile``.
        save_path: Directory the compiled module is written to. It is created
            when it does not exist yet.
        device_id: Device index passed to ``torch_aie.set_device``.
    """
    # Create the output directory up front so a bad path is reported before
    # the compilation runs rather than after it.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    # Output name: <stem>b<batch_size>_torch_aie<ext>, where <stem> is the part
    # of the base name before its first dot.
    suffix = os.path.splitext(model_path)[-1]
    saved_name = os.path.basename(model_path).split('.')[0] + f"b{batch_size}_torch_aie" + suffix
    saved_file = os.path.join(save_path, saved_name)

    trace_model = torch.jit.load(model_path)
    trace_model.eval()
    input_info = [torch_aie.Input((batch_size, 3, 512, 512))]
    torch_aie.set_device(device_id)
    torchaie_model = torch_aie.compile(
        trace_model,
        inputs=input_info,
        allow_tensor_replace_int=True,
        torch_executed_ops=[],
        precision_policy=torch_aie.PrecisionPolicy.FP32,
        soc_version=soc_version,
    )

    torchaie_model.save(saved_file)
    print("[INFO] torch_aie compile for CenterNet finished, model saved in: ", saved_file)

def parse_opt():
    """Parse the command-line options of the compile script.

    Returns:
        argparse.Namespace with the attributes ``torch_script_path``,
        ``batch_size``, ``save_path``, ``soc_version`` and ``device_id``.
    """
    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        ('--torch-script-path', str, './CenterNet_torchscript.pt', 'trace model path'),
        ('--batch-size', int, 1, 'batch size'),
        ('--save-path', str, './', 'compiled model path'),
        ('--soc-version', str, 'Ascend310P3', 'soc version'),
        ('--device-id', int, 0, 'device id'),
    )
    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

def main():
    """Entry point: read the CLI options and compile the TorchScript model."""
    print("[INFO] torch_aie compile for CenterNet start")
    cli_options = parse_opt()
    export_torch_aie(
        model_path=cli_options.torch_script_path,
        batch_size=cli_options.batch_size,
        soc_version=cli_options.soc_version,
        save_path=cli_options.save_path,
        device_id=cli_options.device_id,
    )


# Only run the compilation when executed as a script, not on import.
if __name__ == '__main__':
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import os
import copy

import torch
import torch_aie
import numpy as np
from tqdm import tqdm


def parse_arguments():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with ``aie_module_path``, ``batch_size``,
        ``processed_dataset_path``, ``output_save_path``,
        ``model_input_height``, ``model_input_width``, ``device_id``,
        ``warmup_count`` and ``output_num``.
    """
    # Each entry is (flag, keyword arguments for add_argument); the order is
    # the registration order.
    option_table = [
        ("--aie-module-path", {"default": "./CenterNet_torchscriptb1_torch_aie.pt"}),
        ("--batch-size", {"type": int, "default": 1}),
        ("--processed-dataset-path", {"default": "./prep_dataset/"}),
        ("--output-save-path", {"default": "./result_aie/"}),
        ("--model-input-height", {"type": int, "default": 512, "help": "input tensor height"}),
        ("--model-input-width", {"type": int, "default": 512, "help": "input tensor width"}),
        ("--device-id", {"type": int, "default": 0, "help": "device id"}),
        ("--warmup-count", {"type": int, "default": 5, "help": "warmup count"}),
        ("--output-num", {"type": int, "default": 3, "help": "output num"}),
    ]
    cli = argparse.ArgumentParser(description="inference")
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def load_aie_module(args):
    """Select the target device and load the compiled Torch-AIE module.

    Args:
        args: Parsed options; ``device_id`` and ``aie_module_path`` are read.

    Returns:
        The module returned by ``torch.jit.load``, switched to eval mode.
    """
    # The device is selected before the module is deserialised.
    torch_aie.set_device(args.device_id)
    compiled_module = torch.jit.load(args.aie_module_path)
    compiled_module.eval()
    return compiled_module


def get_total_files(args):
    """Collect every file found under ``args.processed_dataset_path``.

    The directory tree is walked recursively with ``os.walk``; files are
    returned in the order the walk produces them.

    Args:
        args: Parsed options; only ``processed_dataset_path`` is read.

    Returns:
        Tuple ``(count, paths, names)`` where ``paths`` holds the joined file
        paths, ``names`` their base names in the same order, and ``count`` is
        ``len(paths)``.
    """
    collected_paths = [
        os.path.join(directory, entry)
        for directory, _, entries in os.walk(args.processed_dataset_path)
        for entry in entries
    ]
    base_names = list(map(os.path.basename, collected_paths))
    return len(collected_paths), collected_paths, base_names


def generate_batches(args, total_files, file_paths, file_names):
    """Yield ``(batch_tensor, batch_file_names)`` pairs for inference.

    Every file is read as raw float32 data and viewed as
    ``[1, 3, model_input_height, model_input_width]``. Samples are concatenated
    along dimension 0 and moved to ``npu:<device_id>``. When the last batch is
    short it is filled up to ``batch_size`` by repeating its last sample, and
    ``file_names[-1]`` is appended once per padded slot so both outputs keep
    the same length.

    Args:
        args: Parsed options; reads ``batch_size``, ``model_input_height``,
            ``model_input_width`` and ``device_id``.
        total_files: Number of leading entries of ``file_paths`` to consume.
        file_paths: Paths of the preprocessed input files.
        file_names: Base names matching ``file_paths`` index for index.

    Yields:
        Tuple of the batched tensor on the NPU device and the list of file
        names of that batch (including padding entries).
    """
    batch_size = args.batch_size
    sample_shape = [1, 3, args.model_input_height, args.model_input_width]
    device = f"npu:{args.device_id}"

    for start in range(0, total_files, batch_size):
        stop = min(start + batch_size, total_files)
        batch_data = [
            torch.from_numpy(np.fromfile(path, np.float32)).view(sample_shape)
            for path in file_paths[start:stop]
        ]
        batch_file_names = list(file_names[start:stop])

        # Only the final batch can be short.
        padding = batch_size - len(batch_data)
        if padding > 0:
            # torch.cat copies its inputs into a new tensor, so repeating the
            # same tensor object is enough; no deep copy is required.
            batch_data.extend([batch_data[-1]] * padding)
            batch_file_names.extend([file_names[-1]] * padding)

        yield torch.cat(batch_data).to(device), batch_file_names


def main():
    """Run Torch-AIE inference over the preprocessed dataset and report timing.

    Steps: parse the CLI options, load the compiled module, feed every batch
    produced by ``generate_batches`` through it, dump each output head of each
    sample to ``<output_save_path>/<stem>_<j>.bin``, and finally print the
    average latency and throughput measured after the warm-up iterations.
    """
    print("[INFO] CenterNet Torch-AIE inference process start")

    # Parse user input arguments
    args = parse_arguments()
    os.makedirs(args.output_save_path, exist_ok=True)

    # Load AIE module
    aie_module = load_aie_module(args)

    # Generate input data according to batch size
    total_files, file_paths, file_names = get_total_files(args)
    data_generator = generate_batches(args, total_files, file_paths, file_names)
    # The generator yields one item per batch, so the progress bar must count
    # batches rather than files.
    total_batches = (total_files + args.batch_size - 1) // args.batch_size

    # Start inference
    device = f"npu:{args.device_id}"
    inference_time = []
    stream = torch_aie.npu.Stream(device)

    for count, (input_tensor, batched_file_name) in enumerate(tqdm(data_generator, total=total_batches), start=1):
        input_tensor = input_tensor.to(device)
        with torch_aie.npu.stream(stream):
            start_time = time.perf_counter()
            aie_result = aie_module(input_tensor)
            stream.synchronize()
            cost = time.perf_counter() - start_time
        # The first `warmup_count` iterations (5 by default) are warm-up and
        # are excluded from the statistics; `count` starts at 1.
        if count > args.warmup_count:
            inference_time.append(cost)

        # Copy every output head to the host once per batch instead of once
        # per file.
        cpu_results = [aie_result[j].to("cpu") for j in range(args.output_num)]
        for i, file_name in enumerate(batched_file_name):
            file_stem = file_name.split('.')[0]
            for j, cpu_result in enumerate(cpu_results):
                # os.path.join keeps the files inside the output directory
                # whether or not the path was given with a trailing slash.
                cpu_result[i].numpy().tofile(os.path.join(args.output_save_path, f'{file_stem}_{j}.bin'))

    # Calculate inference avg cost and throughput
    if inference_time:
        # NOTE: the average is taken per batch; throughput scales it by batch size.
        avg_seconds = sum(inference_time) / len(inference_time)
        aie_avg_cost = avg_seconds * 1000
        aie_throughput = args.batch_size / avg_seconds

        print(f'\n[INFO] Torch-AIE inference avg cost (batch={args.batch_size}): {aie_avg_cost} ms/pic')
        print(f'[INFO] Throughput = {aie_throughput} pic/s')
    else:
        # Happens when the dataset has no more batches than warm-up steps.
        print(f'\n[WARN] No timed iterations after {args.warmup_count} warm-up step(s); '
              'performance statistics skipped')
    print('[INFO] CenterNet Torch-AIE inference process finished')


if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import sys
import argparse
from glob import glob

import torch
import numpy as np
import cv2
from tqdm import tqdm

# Location of the CenterNet sources, given relative to the working directory.
ROOT = './CenterNet/src/'
# Extend the module search path (once) so the `lib.*` imports that follow
# can be resolved.
if ROOT not in sys.path:
    sys.path.append(ROOT) # add ROOT to PATH

from lib.opts import opts
from lib.detectors.detector_factory import detector_factory
from lib.datasets.dataset_factory import get_dataset
from lib.models.decode import ctdet_decode
from lib.utils.post_process import ctdet_post_process
from lib.models.model import create_model, load_model
import lib.datasets.dataset.coco


def post_process(dets, meta, scale=1, num_classes=80):
    """Convert decoded detections into per-class float32 arrays.

    Args:
        dets: Tensor of decoded detections; it is detached, moved to the CPU
            and reshaped to ``(1, -1, dets.shape[2])``.
        meta: Mapping providing ``'c'``, ``'s'``, ``'out_height'`` and
            ``'out_width'``, which are forwarded to ``ctdet_post_process``.
        scale: Divisor applied to the first four columns of every row.
        num_classes: Number of classes; class ids ``1..num_classes`` are
            processed. Defaults to 80.

    Returns:
        The first element of the ``ctdet_post_process`` result, with every
        class entry replaced by a float32 array of shape ``(-1, 5)``.
        NOTE(review): presumably a dict keyed by 1-based class id -- confirm
        against ``ctdet_post_process``.
    """
    dets = dets.detach().cpu().numpy()
    dets = dets.reshape(1, -1, dets.shape[2])
    dets = ctdet_post_process(
        dets.copy(), [meta['c']], [meta['s']],
        meta['out_height'], meta['out_width'], num_classes)
    per_class = dets[0]
    for j in range(1, num_classes + 1):
        per_class[j] = np.array(per_class[j], dtype=np.float32).reshape(-1, 5)
        per_class[j][:, :4] /= scale
    return per_class


def merge_outputs(detections):
    """Concatenate the per-class detections of several results.

    Args:
        detections: Sequence of mappings, each indexable by the class ids
            1 through 80 and holding an array per class.

    Returns:
        Dict mapping every class id in 1..80 to the float32 concatenation
        (along axis 0) of that class's arrays from all entries.
    """
    class_ids = range(1, 80 + 1)
    return {
        class_id: np.concatenate(
            [single_result[class_id] for single_result in detections], axis=0
        ).astype(np.float32)
        for class_id in class_ids
    }


def run(result_list, index, meta, filenames):
    """Decode the saved inference outputs of one image into detections.

    Three float32 dumps ``<stem>_0.bin``, ``<stem>_1.bin`` and ``<stem>_2.bin``
    are read from ``result_list`` and reshaped to ``(1, 80, 128, 128)``,
    ``(1, 2, 128, 128)`` and ``(1, 2, 128, 128)`` respectively. They are
    decoded with ``ctdet_decode`` and post-processed with ``post_process``.

    Args:
        result_list: Directory that contains the ``.bin`` result files.
        index: Unused; kept so existing callers keep working.
        meta: Pre-processing metadata forwarded to ``post_process``.
        filenames: File name of the image (a single string). Its extension,
            whatever its length, is dropped to obtain ``<stem>``.

    Returns:
        The dict produced by ``merge_outputs`` for this single image.
    """
    # (head name, channel count), listed in the order of the _0/_1/_2 dumps.
    head_layout = (('hm', 80), ('wh', 2), ('reg', 2))
    # splitext instead of a fixed [0:-4] slice: also correct for extensions
    # that are not exactly three characters long.
    stem = os.path.splitext(filenames)[0]

    output = {}
    for i, (head, channels) in enumerate(head_layout):
        buf = np.fromfile(f'{result_list}/{stem}_{i}.bin', dtype="float32")
        output[head] = torch.tensor(buf.reshape(1, channels, 128, 128))

    # The heat map is passed through a sigmoid (in place) before decoding.
    hm = output['hm'].sigmoid_()
    detss = ctdet_decode(hm, output['wh'], output['reg'])
    dets = post_process(detss, meta)
    return merge_outputs([dets])


# Post-processing entry point: for every image of the validation split, turn
# the raw Torch-AIE output dumps into per-class detections and store them as
# <save_dir>/<img_id>.npz.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CenterNet')
    parser.add_argument('--bin-data-path', default='./result_aie', type=str, help='Torch-AIE infer result path')
    parser.add_argument('--dataset', default='/data/datasets', type=str, help='COCO dataset path')
    parser.add_argument('--save-dir', default='./postprocessed', type=str, help='postprocessed files save path')
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.save_dir, exist_ok=True)

    new_datapath = args.dataset
    opt = opts().parse(['ctdet', '--load_model', './ctdet_coco_dla_2x.pth'])
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    opt.data_dir = new_datapath
    Detector = detector_factory[opt.task]
    dataset = Dataset(opt, 'val')
    # NOTE(review): -1 presumably selects CPU execution for the detector -- confirm.
    opt.gpus[0] = -1
    detector = Detector(opt)
    num_iters = len(dataset)

    for ind in tqdm(range(num_iters)):
        img_id = dataset.images[ind]
        img_info = dataset.coco.loadImgs(ids=[img_id])[0]
        img_path = os.path.join(dataset.img_dir, img_info['file_name'])
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'failed to read image: {img_path}')
        # Only the metadata of the pre-processing step is needed here.
        _, metas = detector.pre_process(image, 1.0, meta=None)
        ret = run(args.bin_data_path, ind, metas, img_info['file_name'])
        np.savez(os.path.join(args.save_dir, str(img_id)), dic=ret)
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import argparse
from glob import glob

import torch
import numpy as np
import cv2
from tqdm import tqdm

# Location of the CenterNet sources, given relative to the working directory.
ROOT = './CenterNet/src/'
# Extend the module search path (once) so the `lib.*` imports that follow
# can be resolved.
if ROOT not in sys.path:
    sys.path.append(ROOT)

from lib.opts import opts
from lib.detectors.detector_factory import detector_factory
from lib.datasets.dataset_factory import get_dataset
from lib.models.decode import ctdet_decode
from lib.utils.post_process import ctdet_post_process
from lib.models.model import create_model, load_model
import lib.datasets.dataset.coco

# Evaluation entry point: gather the per-image post-processed detections and
# hand them to the dataset's own evaluation routine.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CenterNet')
    parser.add_argument('--dataset', default='/data/datasets', type=str, help='dataset')
    parser.add_argument('--resultfolder', default='./run_eval_result', type=str, help='Dir to save results')
    parser.add_argument('--postprocessed_dir', default='./postprocessed', type=str, help='Dir that contains postprocessed results')
    args = parser.parse_args()

    new_datapath = args.dataset
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(args.resultfolder, exist_ok=True)
    opt = opts().parse(['ctdet', '--load_model', './ctdet_coco_dla_2x.pth'])
    Dataset = get_dataset(opt.dataset, opt.task)
    opt = opts().update_dataset_info_and_set_heads(opt, Dataset)
    opt.data_dir = new_datapath
    dataset = Dataset(opt, 'val')
    opt.gpus[0] = -1
    results = {}
    num_iters = len(dataset)

    for ind in tqdm(range(num_iters)):
        img_id = dataset.images[ind]
        npz_path = os.path.join(args.postprocessed_dir, str(img_id) + '.npz')
        # NOTE: allow_pickle=True unpickles the stored dict, which can execute
        # arbitrary code; only point --postprocessed_dir at trusted files.
        # The context manager closes the archive after each image.
        with np.load(npz_path, allow_pickle=True) as archive:
            results[img_id] = archive['dic'].tolist()
    dataset.run_eval(results, args.resultfolder)
Loading

0 comments on commit a1f876b

Please sign in to comment.