-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathmain.py
137 lines (109 loc) · 6.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import tempfile
from pathlib import Path
import numpy as np
import cv2 # opencv-python
from ultralytics import YOLO
import deep_sort.deep_sort.deep_sort as ds
def putTextWithBackground(img, text, origin, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, text_color=(255, 255, 255), bg_color=(0, 0, 0), thickness=1):
"""绘制带有背景的文本。
:param img: 输入图像。
:param text: 要绘制的文本。
:param origin: 文本的左上角坐标。
:param font: 字体类型。
:param font_scale: 字体大小。
:param text_color: 文本的颜色。
:param bg_color: 背景的颜色。
:param thickness: 文本的线条厚度。
"""
# 计算文本的尺寸
(text_width, text_height), _ = cv2.getTextSize(text, font, font_scale, thickness)
# 绘制背景矩形
bottom_left = origin
top_right = (origin[0] + text_width, origin[1] - text_height - 5) # 减去5以留出一些边距
cv2.rectangle(img, bottom_left, top_right, bg_color, -1)
# 在矩形上绘制文本
text_origin = (origin[0], origin[1] - 5) # 从左上角的位置减去5来留出一些边距
cv2.putText(img, text, text_origin, font, font_scale, text_color, thickness, lineType=cv2.LINE_AA)
def extract_detections(results, detect_class):
"""
从模型结果中提取和处理检测信息。
- results: YoloV8模型预测结果,包含检测到的物体的位置、类别和置信度等信息。
- detect_class: 需要提取的目标类别的索引。
参考: https://docs.ultralytics.com/modes/predict/#working-with-results
"""
# 初始化一个空的二维numpy数组,用于存放检测到的目标的位置信息
# 如果视频中没有需要提取的目标类别,如果不初始化,会导致tracker报错
detections = np.empty((0, 4))
confarray = [] # 初始化一个空列表,用于存放检测到的目标的置信度。
# 遍历检测结果
# 参考:https://docs.ultralytics.com/modes/predict/#working-with-results
for r in results:
for box in r.boxes:
# 如果检测到的目标类别与指定的目标类别相匹配,提取目标的位置信息和置信度
if box.cls[0].int() == detect_class:
x1, y1, x2, y2 = box.xywh[0].int().tolist() # 提取目标的位置信息,并从tensor转换为整数列表。
conf = round(box.conf[0].item(), 2) # 提取目标的置信度,从tensor中取出浮点数结果,并四舍五入到小数点后两位。
detections = np.vstack((detections, np.array([x1, y1, x2, y2]))) # 将目标的位置信息添加到detections数组中。
confarray.append(conf) # 将目标的置信度添加到confarray列表中。
return detections, confarray # 返回提取出的位置信息和置信度。
# 视频处理
def detect_and_track(input_path: str, output_path: str, detect_class: int, model, tracker) -> Path:
"""
处理视频,检测并跟踪目标。
- input_path: 输入视频文件的路径。
- output_path: 处理后视频保存的路径。
- detect_class: 需要检测和跟踪的目标类别的索引。
- model: 用于目标检测的模型。
- tracker: 用于目标跟踪的模型。
"""
cap = cv2.VideoCapture(input_path) # 使用OpenCV打开视频文件。
if not cap.isOpened(): # 检查视频文件是否成功打开。
print(f"Error opening video file {input_path}")
return None
fps = cap.get(cv2.CAP_PROP_FPS) # 获取视频的帧率
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) # 获取视频的分辨率(宽度和高度)。
output_video_path = Path(output_path) / "output.avi" # 设置输出视频的保存路径。
# 设置视频编码格式为XVID格式的avi文件
# 如果需要使用h264编码或者需要保存为其他格式,可能需要下载openh264-1.8.0
# 下载地址:https://github.com/cisco/openh264/releases/tag/v1.8.0
# 下载完成后将dll文件放在当前文件夹内
fourcc = cv2.VideoWriter_fourcc(*"XVID")
output_video = cv2.VideoWriter(output_video_path.as_posix(), fourcc, fps, size, isColor=True) # 创建一个VideoWriter对象用于写视频。
# 对每一帧图片进行读取和处理
while True:
success, frame = cap.read() # 逐帧读取视频。
# 如果读取失败(或者视频已处理完毕),则跳出循环。
if not (success):
break
# 使用YoloV8模型对当前帧进行目标检测。
results = model(frame, stream=True)
# 从预测结果中提取检测信息。
detections, confarray = extract_detections(results, detect_class)
# 使用deepsort模型对检测到的目标进行跟踪。
resultsTracker = tracker.update(detections, confarray, frame)
for x1, y1, x2, y2, Id in resultsTracker:
x1, y1, x2, y2 = map(int, [x1, y1, x2, y2]) # 将位置信息转换为整数。
# 绘制bounding box和文本
cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 255), 3)
putTextWithBackground(frame, str(int(Id)), (max(-10, x1), max(40, y1)), font_scale=1.5, text_color=(255, 255, 255), bg_color=(255, 0, 255))
output_video.write(frame) # 将处理后的帧写入到输出视频文件中。
output_video.release() # 释放VideoWriter对象。
cap.release() # 释放视频文件。
print(f'output dir is: {output_video_path}')
return output_video_path
if __name__ == "__main__":
# 指定输入视频的路径。
######
input_path = "test.mp4"
######
# 输出文件夹,默认为系统的临时文件夹路径
output_path = tempfile.mkdtemp() # 创建一个临时目录用于存放输出视频。
# 加载yoloV8模型权重
model = YOLO("yolov8n.pt")
# 设置需要检测和跟踪的目标类别
# yoloV8官方模型的第一个类别为'person'
detect_class = 0
print(f"detecting {model.names[detect_class]}") # model.names返回模型所支持的所有物体类别
# 加载DeepSort模型
tracker = ds.DeepSort("deep_sort/deep_sort/deep/checkpoint/ckpt.t7")
detect_and_track(input_path, output_path, detect_class, model, tracker)