run_detector.py
# FootAndBall: Integrated Player and Ball Detector
# Jacek Komorowski, Grzegorz Kurzejamski, Grzegorz Sarwas
# Copyright (c) 2020 Sport Algorithmics and Gaming
#
# Run FootAndBall detector on ISSIA-CNR Soccer videos
#
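# A minimal usage sketch (the file paths below are placeholders, not files shipped with the repository):
#
#   python run_detector.py --path <input_video.avi> --weights <model_weights.pth> \
#       --out_video <output_with_detections.avi> --device cuda:0
#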
import torch
import cv2
import os
import argparse
import tqdm

import network.footandball as footandball
import data.augmentation as augmentations
from data.augmentation import PLAYER_LABEL, BALL_LABEL


def draw_bboxes(image, detections):
    # Draw player bounding boxes and ball markers, annotated with detection confidence scores
    font = cv2.FONT_HERSHEY_SIMPLEX
    for box, label, score in zip(detections['boxes'], detections['labels'], detections['scores']):
        if label == PLAYER_LABEL:
            x1, y1, x2, y2 = box
            color = (255, 0, 0)
            cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            cv2.putText(image, '{:0.2f}'.format(score), (int(x1), max(0, int(y1) - 10)), font, 1, color, 2)
        elif label == BALL_LABEL:
            x1, y1, x2, y2 = box
            x = int((x1 + x2) / 2)
            y = int((y1 + y2) / 2)
            color = (0, 0, 255)
            radius = 25
            cv2.circle(image, (int(x), int(y)), radius, color, 2)
            cv2.putText(image, '{:0.2f}'.format(score), (max(0, int(x - radius)), max(0, (y - radius - 10))), font, 1,
                        color, 2)

    return image


def run_detector(model: footandball.FootAndBall, args: argparse.Namespace):
    model.print_summary(show_architecture=False)
    model = model.to(args.device)

    _, file_name = os.path.split(args.path)

    # Load trained weights (map to CPU storage when running without CUDA)
    if args.device == 'cpu':
        print('Loading CPU weights...')
        state_dict = torch.load(args.weights, map_location=lambda storage, loc: storage)
    else:
        print('Loading GPU weights...')
        state_dict = torch.load(args.weights)

    model.load_state_dict(state_dict)
    # Set model to evaluation mode
    model.eval()

    # Open the input video and prepare the output writer with matching frame rate and resolution
    sequence = cv2.VideoCapture(args.path)
    fps = sequence.get(cv2.CAP_PROP_FPS)
    (frame_width, frame_height) = (int(sequence.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                   int(sequence.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    n_frames = int(sequence.get(cv2.CAP_PROP_FRAME_COUNT))
    out_sequence = cv2.VideoWriter(args.out_video, cv2.VideoWriter_fourcc(*'XVID'), fps,
                                   (frame_width, frame_height))

    print('Processing video: {}'.format(args.path))
    pbar = tqdm.tqdm(total=n_frames)
    while sequence.isOpened():
        ret, frame = sequence.read()
        if not ret:
            # End of video
            break

        # Convert color space from BGR to RGB, convert to tensor and normalize
        img_tensor = augmentations.numpy2tensor(frame)

        with torch.no_grad():
            # Add dimension for the batch size
            img_tensor = img_tensor.unsqueeze(dim=0).to(args.device)
            detections = model(img_tensor)[0]

        # Draw detections on the original (BGR) frame and append it to the output video
        frame = draw_bboxes(frame, detections)
        out_sequence.write(frame)
        pbar.update(1)

    pbar.close()
    sequence.release()
    out_sequence.release()


if __name__ == '__main__':
    print('Run FootAndBall detector on input video')

    # Parse command line arguments for running the detector
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', help='path to video', type=str, required=True)
    parser.add_argument('--model', help='model name', type=str, default='fb1')
    parser.add_argument('--weights', help='path to model weights', type=str, required=True)
    parser.add_argument('--ball_threshold', help='ball confidence detection threshold', type=float, default=0.7)
    parser.add_argument('--player_threshold', help='player confidence detection threshold', type=float, default=0.7)
    parser.add_argument('--out_video', help='path to video with detection results', type=str, required=True)
    parser.add_argument('--device', help='device (CPU or CUDA)', type=str, default='cuda:0')
    args = parser.parse_args()

    print('Video path: {}'.format(args.path))
    print('Model: {}'.format(args.model))
    print('Model weights path: {}'.format(args.weights))
    print('Ball confidence detection threshold [0..1]: {}'.format(args.ball_threshold))
    print('Player confidence detection threshold [0..1]: {}'.format(args.player_threshold))
    print('Output video path: {}'.format(args.out_video))
    print('Device: {}'.format(args.device))
    print('')

    assert os.path.exists(args.weights), 'Cannot find FootAndBall model weights: {}'.format(args.weights)
    assert os.path.exists(args.path), 'Cannot open video: {}'.format(args.path)

    model = footandball.model_factory(args.model, 'detect', ball_threshold=args.ball_threshold,
                                      player_threshold=args.player_threshold)
    run_detector(model, args)