From 84b003d324d1d5cbefc17d6e1cbf12a0680ea638 Mon Sep 17 00:00:00 2001
From: zaki1003
Date: Sun, 12 Mar 2023 20:53:38 +0100
Subject: [PATCH 1/3] Adding video prediction

---
 video_demo.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 120 insertions(+)
 create mode 100644 video_demo.py

diff --git a/video_demo.py b/video_demo.py
new file mode 100644
index 0000000..78a62cb
--- /dev/null
+++ b/video_demo.py
@@ -0,0 +1,120 @@
+import argparse
+from pathlib import Path
+
+import torch
+import torchvision.transforms as standard_transforms
+import numpy as np
+import cv2
+
+from models import build_model
+import os
+import warnings
+warnings.filterwarnings('ignore')
+
+def get_args_parser():
+    parser = argparse.ArgumentParser('Set parameters for P2PNet evaluation', add_help=False)
+
+    # * Backbone
+    parser.add_argument('--backbone', default='vgg16_bn', type=str,
+                        help="name of the convolutional backbone to use")
+
+    parser.add_argument('--row', default=2, type=int,
+                        help="row number of anchor points")
+    parser.add_argument('--line', default=2, type=int,
+                        help="line number of anchor points")
+
+    parser.add_argument('--output_dir', default='',
+                        help='path where outputs are saved')
+    parser.add_argument('--weight_path', default='',
+                        help='path to the trained weights')
+    parser.add_argument('--video_path', default='',
+                        help='path to the input video')
+    parser.add_argument('--gpu_id', default=0, type=int, help='the gpu used for evaluation')
+
+    return parser
+
+def main(args, debug=False):
+    os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)
+    print(args)
+    device = torch.device('cuda')
+
+    # build the P2PNet model and move it to the GPU
+    model = build_model(args)
+    model.to(device)
+
+    # load the trained weights; fall back to the bundled checkpoint if none is given
+    weight_path = args.weight_path if args.weight_path else './weights/SHTechA.pth'
+    checkpoint = torch.load(Path(weight_path), map_location='cpu')
+    model.load_state_dict(checkpoint['model'])
+
+    # convert to eval mode
+    model.eval()
+
+    # create the pre-processing transform
+    transform = standard_transforms.Compose([
+        standard_transforms.ToTensor(),
+        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+    ])
+
+    # input video; fall back to a local demo clip if no path is given
+    video_path = args.video_path if args.video_path else './fair.mp4'
+    cap = cv2.VideoCapture(video_path)
+
+    # output video, written at a fixed 1280x1280 resolution
+    fourcc = cv2.VideoWriter_fourcc(*'XVID')
+    out = cv2.VideoWriter('./demo.avi', fourcc, 30, (1280, 1280))
+
+    # frames are downscaled before inference to speed it up
+    scale_factor = 0.4
+
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            print("Test End")
+            break
+
+        frame = cv2.resize(frame, (0, 0), fx=scale_factor, fy=scale_factor)
+        img_raw = frame.copy()
+        ori_img = frame.copy()
+
+        # pre-processing
+        img = transform(frame)
+        samples = img.unsqueeze(0)
+        samples = samples.to(device)
+
+        with torch.no_grad():
+            # run inference
+            outputs = model(samples)
+            outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
+            outputs_points = outputs['pred_points'][0]
+
+            threshold = 0.5
+            # filter the predictions
+            points = outputs_points[outputs_scores > threshold].detach().cpu().numpy().tolist()
+            predict_cnt = int((outputs_scores > threshold).sum())
+            print("Count: ", predict_cnt)
+
+            # draw the predictions
+            size = 2
+            img_to_draw = img_raw
+            for p in points:
+                img_to_draw = cv2.circle(img_to_draw, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)
+
+            # stack the original frame above the annotated one
+            res = np.vstack((ori_img, img_to_draw))
+            cv2.putText(res, "Count:" + str(predict_cnt), (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+
+            # save the visualized image (overwritten every frame)
+            cv2.imwrite('./demo.jpg', res)
+
+            # write to the output video
+            res = cv2.resize(res, (1280, 1280))
+            out.write(res)
+
+    cap.release()
+    out.release()
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('P2PNet evaluation script', parents=[get_args_parser()])
+    args = parser.parse_args()
+    main(args)

From 5f604f3e45e6a5eac008d1c95433ff8767458a5b Mon Sep 17 00:00:00 2001
From: zaki1003 <65148928+zaki1003@users.noreply.github.com>
Date: Sun, 12 Mar 2023 20:21:29 +0100
Subject: [PATCH 2/3] Adding video prediction

---
 README.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/README.md b/README.md
index fb798b1..beab86b 100644
--- a/README.md
+++ b/README.md
@@ -136,9 +136,17 @@ By default, a periodic evaluation will be conducted on the validation set.
 
 A trained model (with an MAE of **51.96**) on SHTechPartA is available at "./weights", run the following commands to launch a visualization demo:
 
+### Testing on an image
 ```
 CUDA_VISIBLE_DEVICES=0 python run_test.py --weight_path ./weights/SHTechA.pth --output_dir ./logs/
 ```
+### Testing on a video
+```
+CUDA_VISIBLE_DEVICES=0 python video_demo.py --weight_path ./weights/SHTechA.pth
+```
+#### A demo of crowd counting on a video
+![demo](https://user-images.githubusercontent.com/65148928/224567663-2434449f-fad4-44cb-8806-1ec0cfe518fc.gif)
+
 
 ## Acknowledgements

From 6784edb95782e1597ee4fc55e02a787bde5ad79a Mon Sep 17 00:00:00 2001
From: zaki1003 <65148928+zaki1003@users.noreply.github.com>
Date: Tue, 21 Mar 2023 12:03:51 +0100
Subject: [PATCH 3/3] Showing the video when counting starts

---
 video_demo.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/video_demo.py b/video_demo.py
index 78a62cb..ce5cb21 100644
--- a/video_demo.py
+++ b/video_demo.py
@@ -110,6 +110,11 @@ def main(args, debug=False):
             # write to the output video
             res = cv2.resize(res, (1280, 1280))
             out.write(res)
+
+            cv2.putText(img_to_draw, "Count:" + str(predict_cnt), (30, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+            cv2.imshow("dst", img_to_draw)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
 
     cap.release()
     out.release()
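
For reference, the score-thresholding step that both video patches build on can be exercised on its own. Below is a minimal sketch, assuming a P2PNet-style output dict with `pred_logits` and `pred_points` as returned by `model(samples)` in the patched `video_demo.py`; the helper name `count_heads` is illustrative and not part of the repository:

```
import torch

def count_heads(outputs, threshold=0.5):
    # class-1 (person) confidence for every proposed point, batch item 0
    scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
    points = outputs['pred_points'][0]
    keep = scores > threshold
    # (x, y) coordinates of the kept points, plus the predicted head count
    return points[keep].detach().cpu().numpy().tolist(), int(keep.sum())
```

Factored out this way, the per-frame loop body in `video_demo.py` would reduce to `points, predict_cnt = count_heads(outputs)` followed by the drawing and writing code.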