diff --git a/src/icatcher/__init__.py b/src/icatcher/__init__.py
index e7d4644..4b3b01f 100644
--- a/src/icatcher/__init__.py
+++ b/src/icatcher/__init__.py
@@ -2,7 +2,7 @@
 __version__ = "0.2.0"
 version = __version__
 ### define classes
-classes = {"noface": -2, "nobabyface": -1, "away": 0, "left": 1, "right": 2}
+classes = {"none": -3, "noface": -2, "nobabyface": -1, "away": 0, "left": 1, "right": 2}
 reverse_classes = {v: k for k, v in classes.items()}
 ### imports
 from . import (
diff --git a/src/icatcher/cli.py b/src/icatcher/cli.py
index 5a1e52d..3424e72 100644
--- a/src/icatcher/cli.py
+++ b/src/icatcher/cli.py
@@ -296,10 +296,10 @@ def predict_from_video(opt):
     :return:
     """
     # initialize
-    loc = (
-        -5
-    )  # where in the sliding window to take the prediction (should be a function of opt.sliding_window_size)
-    cursor = -5  # points to the frame we will write to output relative to current frame
+    # loc determines where in the sliding window to take the prediction, fixed to be the middle frame
+    loc = -((opt.sliding_window_size // 2) + 1)
+    # cursor points to the frame we will write to output relative to current frame, it can change based on illegal transitions
+    cursor = -((opt.sliding_window_size // 2) + 1)
     logging.debug(
         "using the following values for per-channel mean: {}".format(
             opt.per_channel_mean
@@ -434,7 +434,7 @@ def predict_from_video(opt):
            confidences.append(-1)
            image = np.zeros((1, opt.image_size, opt.image_size, 3), np.float64)
            my_box = np.array([0, 0, 0, 0, 0])
-            image_sequence.append((image, True))
+            image_sequence.append((image, False))
            box_sequence.append(my_box)
            bbox_sequence.append(None)
            from_tracker.append(False)
@@ -455,39 +455,37 @@ def predict_from_video(opt):
            )
            crop, my_box = extract_crop(frame, selected_bbox, opt)
            if selected_bbox is None:
-                answers.append(
-                    classes["nobabyface"]
-                )  # if selecting face fails, treat as away and mark invalid
+                # if selecting face fails, treat as away and mark invalid
+                answers.append(classes["nobabyface"])
                confidences.append(-1)
                image = np.zeros((1, opt.image_size, opt.image_size, 3), np.float64)
                my_box = np.array([0, 0, 0, 0, 0])
-                image_sequence.append((image, True))
+                image_sequence.append((image, False))
                box_sequence.append(my_box)
                bbox_sequence.append(None)
            else:
+                # if face detector succeeds, treat as "none" (will be overwritten later) and mark valid
                if crop.size == 0:
                    raise ValueError("crop size is 0, what just happend?")
-                answers.append(
-                    classes["left"]
-                )  # if face detector succeeds, treat as left and mark valid
+                answers.append(classes["none"])
                confidences.append(-1)
-                image_sequence.append((crop, False))
+                image_sequence.append((crop, True))
                box_sequence.append(my_box)
                bbox_sequence.append(selected_bbox)
                if not from_tracker[-1]:
                    last_known_valid_bbox = selected_bbox.copy()
-        if (
-            len(image_sequence) == opt.sliding_window_size
-        ):  # we have enough frames for prediction, predict for middle frame
+        if frame_count + 1 >= np.abs(cursor):
+            # sets important variables to cursor location
            cur_frame = frames[cursor]
            cur_bbox = bbox_sequence[cursor]
            is_from_tracker = from_tracker[cursor]
+        if len(image_sequence) == opt.sliding_window_size:
+            # we have enough frames for prediction, predict for middle frame
            frames.pop(0)
            bbox_sequence.pop(0)
            from_tracker.pop(0)
-            if not image_sequence[opt.sliding_window_size // 2][
-                1
-            ]:  # if middle image is valid
+            if image_sequence[opt.sliding_window_size // 2][1]:
+                # if middle image is valid
                to_predict = {
                    "imgs": torch.tensor(
                        np.array([x[0] for x in image_sequence[0::2]]),
@@ -509,10 +507,10 @@ def predict_from_video(opt):
                confidence, _ = torch.max(probs, 1)
                float32_conf = confidence.cpu().numpy()[0]
                int32_pred = prediction.cpu().numpy()[0]
-                answers[loc] = int32_pred  # update answers for the middle frame
-                confidences[
-                    loc
-                ] = float32_conf  # update confidences for the middle frame
+                # update answers for the middle frame
+                answers[loc] = int32_pred
+                # update confidences for the middle frame
+                confidences[loc] = float32_conf
            image_sequence.pop(0)
            box_sequence.pop(0)
@@ -525,6 +523,8 @@ def predict_from_video(opt):
                illegal_transitions,
                corrected_transitions,
            )
+        # report results at cursor
+        if frame_count + 1 >= np.abs(cursor):
            class_text = reverse_classes[answers[cursor]]
            if opt.mirror_annotation:
                if class_text == "left":
@@ -533,62 +533,49 @@ def predict_from_video(opt):
                    class_text = "left"
            if opt.on_off:
                class_text = "off" if class_text == "away" else "on"
-            if opt.output_video_path:
-                if is_from_tracker and opt.track_face:
-                    rect_color = (0, 0, 255)
-                else:
-                    rect_color = (0, 255, 0)
-                draw.prepare_frame(
-                    cur_frame,
-                    cur_bbox,
-                    show_arrow=True,
-                    rect_color=rect_color,
-                    conf=confidences[cursor],
-                    class_text=class_text,
-                    frame_number=frame_count,
-                    pic_in_pic=opt.pic_in_pic,
-                )
-                video_output_file.write(cur_frame)
-            if opt.show_output:
-                if is_from_tracker and opt.track_face:
-                    rect_color = (0, 0, 255)
-                else:
-                    rect_color = (0, 255, 0)
-                draw.prepare_frame(
-                    cur_frame,
-                    cur_bbox,
-                    show_arrow=True,
-                    rect_color=rect_color,
-                    conf=confidences[cursor],
-                    class_text=class_text,
-                    frame_number=frame_count,
-                    pic_in_pic=opt.pic_in_pic,
-                )
-
-                cv2.imshow("frame", cur_frame)
-                if cv2.waitKey(1) & 0xFF == ord("q"):
-                    break
-            # handle writing output to file
-            if opt.output_annotation:
-                if opt.output_format == "raw_output":
-                    with open(prediction_output_file, "a", newline="") as f:
-                        f.write(
-                            "{}, {}, {:.02f}\n".format(
-                                str(frame_count + cursor + 1),
-                                class_text,
-                                confidences[cursor],
-                            )
-                        )
-            logging.info(
-                "frame: {}, class: {}, confidence: {:.02f}, cur_fps: {:.02f}".format(
-                    str(frame_count + cursor + 1),
-                    class_text,
-                    confidences[cursor],
-                    cur_fps(),
-                )
+            user_abort = handle_output(
+                opt,
+                is_from_tracker,
+                cur_frame,
+                cur_bbox,
+                confidences[cursor],
+                cursor,
+                class_text,
+                frame_count,
+                video_output_file,
+                prediction_output_file,
+                cur_fps,
            )
+            if user_abort:
+                break
        ret_val, frame = cap.read()
        frame_count += 1
+    if not user_abort:
+        for i in range(
+            opt.sliding_window_size - np.abs(cursor), opt.sliding_window_size - 1
+        ):
+            # report for final left over frames
+            class_text = "none"
+            cur_frame = frames[i]
+            cur_bbox = bbox_sequence[i]
+            is_from_tracker = from_tracker[i]
+            user_abort = handle_output(
+                opt,
+                is_from_tracker,
+                cur_frame,
+                cur_bbox,
+                -1,
+                cursor,
+                class_text,
+                frame_count,
+                video_output_file,
+                prediction_output_file,
+                cur_fps,
+            )
+            frame_count = frame_count + 1
+            if user_abort:
+                break
    # finished processing a video file, cleanup
    cleanup(
        video_output_file,
@@ -602,6 +589,77 @@ def predict_from_video(opt):
    )


+def handle_output(
+    opt,
+    is_from_tracker,
+    cur_frame,
+    cur_bbox,
+    confidence,
+    cursor,
+    class_text,
+    frame_count,
+    video_output_file,
+    prediction_output_file,
+    cur_fps,
+):
+    # utility function to handle output (video, live stream, annotations, logging, etc.)
+    if opt.output_video_path:
+        if is_from_tracker and opt.track_face:
+            rect_color = (0, 0, 255)
+        else:
+            rect_color = (0, 255, 0)
+        draw.prepare_frame(
+            cur_frame,
+            cur_bbox,
+            show_arrow=True,
+            rect_color=rect_color,
+            conf=confidence,
+            class_text=class_text,
+            frame_number=frame_count + cursor + 1,
+            pic_in_pic=opt.pic_in_pic,
+        )
+        video_output_file.write(cur_frame)
+    if opt.show_output:
+        if is_from_tracker and opt.track_face:
+            rect_color = (0, 0, 255)
+        else:
+            rect_color = (0, 255, 0)
+        draw.prepare_frame(
+            cur_frame,
+            cur_bbox,
+            show_arrow=True,
+            rect_color=rect_color,
+            conf=confidence,
+            class_text=class_text,
+            frame_number=frame_count + cursor + 1,
+            pic_in_pic=opt.pic_in_pic,
+        )
+
+        cv2.imshow("frame", cur_frame)
+        if cv2.waitKey(1) & 0xFF == ord("q"):
+            return True
+    # handle writing output to file
+    if opt.output_annotation:
+        if opt.output_format == "raw_output":
+            with open(prediction_output_file, "a", newline="") as f:
+                f.write(
+                    "{}, {}, {:.02f}\n".format(
+                        str(frame_count + cursor + 1),
+                        class_text,
+                        confidence,
+                    )
+                )
+    logging.info(
+        "frame: {}, class: {}, confidence: {:.02f}, cur_fps: {:.02f}".format(
+            str(frame_count + cursor + 1),
+            class_text,
+            confidence,
+            cur_fps(),
+        )
+    )
+    return False
+
+
 def cleanup(
     video_output_file,
     prediction_output_file,
@@ -612,6 +670,7 @@ def cleanup(
     cap,
     opt,
 ):
+    # saves and frees resources
     if opt.show_output:
         cv2.destroyAllWindows()
     if opt.output_video_path:
diff --git a/src/icatcher/options.py b/src/icatcher/options.py
index b863a51..b795974 100644
--- a/src/icatcher/options.py
+++ b/src/icatcher/options.py
@@ -242,6 +242,8 @@ def parse_arguments(my_string=None):
        raise ValueError(
            "On off mode can only be used with raw output format. Pass raw_output with the --output_format flag."
        )
+    if args.sliding_window_size % 2 == 0:
+        raise ValueError("sliding_window_size must be odd.")
    if not args.per_channel_mean:
        args.per_channel_mean = [0.485, 0.456, 0.406]
    if not args.per_channel_std:
diff --git a/tests/test_basic.py b/tests/test_basic.py
index b619ac0..741b699 100644
--- a/tests/test_basic.py
+++ b/tests/test_basic.py
@@ -32,7 +32,7 @@ def test_process_video():
     """
     tests processing a video file.
""" - arguments = "tests/test_data/test.mp4" + arguments = "tests/test_data/test_short.mp4" opt = icatcher.options.parse_arguments(arguments) source = Path(opt.source) ( @@ -57,37 +57,67 @@ def test_mask(): @pytest.mark.parametrize( - "args_string", + "args_string, result_file", [ - "tests/test_data/test.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --overwrite", - "tests/test_data/test.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data", - "tests/test_data/test.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --mirror_annotation --overwrite", - "tests/test_data/test.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --output_format compressed --overwrite", - "tests/test_data/test.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --mirror_annotation --output_format compressed --overwrite", + # ( + # "tests/test_data/test_long.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --overwrite", + # "tests/test_data/test_long_result.txt", + # ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --overwrite", + "tests/test_data/test_short_result.txt", + ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data", + "tests/test_data/test_short_result.txt", + ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --illegal_transitions_path tests/test_data/illegal_transitions_short.csv --overwrite", + "tests/test_data/test_short_illegal_result.txt", + ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --mirror_annotation --overwrite", + "tests/test_data/test_short_result.txt", + ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --output_format compressed --overwrite", + "tests/test_data/test_short_result.txt", + ), + ( + "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --fd_model opencv_dnn --output_annotation tests/test_data --mirror_annotation --output_format compressed --overwrite", + "tests/test_data/test_short_result.txt", + ), ], ) -def test_predict_from_video(args_string): +def test_predict_from_video(args_string, result_file): """ runs entire prediction pipeline with several command line options. + note this uses the original paper models which are faster but less accurate. + tests for the newer models is out of scope for this test. 
""" + result_file = Path(result_file) + with open(result_file, "r") as f: + gt_data = f.readlines() + gt_classes = [x.split(",")[1].strip() for x in gt_data] + gt_classes = np.array([icatcher.classes[x] for x in gt_classes]) + gt_confidences = np.array([float(x.split(",")[2].strip()) for x in gt_data]) args = icatcher.options.parse_arguments(args_string) if not args.overwrite: try: predict_from_video(args) - except ( - FileExistsError - ): # should be raised if overwrite is False and file exists, which is expected since this is the second test + except FileExistsError: + # should be raised if overwrite is False and file exists, which is expected since this is not the first test return else: predict_from_video(args) if args.output_annotation: if args.output_format == "compressed": - output_file = Path("tests/test_data/test.npz") + output_file = Path("tests/test_data/{}.npz".format(Path(args.source).stem)) data = np.load(output_file) predicted_classes = data["arr_0"] confidences = data["arr_1"] else: - output_file = Path("tests/test_data/test.txt") + output_file = Path("tests/test_data/{}.txt".format(Path(args.source).stem)) with open(output_file, "r") as f: data = f.readlines() predicted_classes = [x.split(",")[1].strip() for x in data] @@ -96,8 +126,15 @@ def test_predict_from_video(args_string): ) confidences = np.array([float(x.split(",")[2].strip()) for x in data]) assert len(predicted_classes) == len(confidences) - # assert len(predicted_classes) == 194 # hard coded number of frames in test video + assert len(predicted_classes) == len(gt_classes) if args.mirror_annotation: - assert (predicted_classes == 2).all() + modfied_predicted_classes = predicted_classes.copy() + # 999 is just a dummy value + modfied_predicted_classes[modfied_predicted_classes == 1] = 999 + modfied_predicted_classes[modfied_predicted_classes == 2] = 1 + modfied_predicted_classes[modfied_predicted_classes == 999] = 2 + assert (modfied_predicted_classes == gt_classes).all() + np.isclose(gt_confidences, confidences, 0.01).all() else: - assert (predicted_classes == 1).all() + assert (predicted_classes == gt_classes).all() + np.isclose(gt_confidences, confidences, 0.01).all() diff --git a/tests/test_data/illegal_transitions_short.csv b/tests/test_data/illegal_transitions_short.csv new file mode 100644 index 0000000..c457a81 --- /dev/null +++ b/tests/test_data/illegal_transitions_short.csv @@ -0,0 +1,2 @@ +illegal_transition,corrected_transition +"1111","2222" \ No newline at end of file diff --git a/tests/test_data/test.mp4 b/tests/test_data/test_short.mp4 similarity index 100% rename from tests/test_data/test.mp4 rename to tests/test_data/test_short.mp4 diff --git a/tests/test_data/test_short_illegal_result.txt b/tests/test_data/test_short_illegal_result.txt new file mode 100644 index 0000000..56090eb --- /dev/null +++ b/tests/test_data/test_short_illegal_result.txt @@ -0,0 +1,194 @@ +0, none, -1.00 +1, none, -1.00 +2, none, -1.00 +3, none, -1.00 +4, right, -1.00 +5, right, -1.00 +6, right, -1.00 +7, right, -1.00 +8, right, -1.00 +9, right, -1.00 +10, right, -1.00 +11, right, -1.00 +12, right, -1.00 +13, right, -1.00 +14, right, -1.00 +15, right, -1.00 +16, right, -1.00 +17, right, -1.00 +18, right, -1.00 +19, right, -1.00 +20, right, -1.00 +21, right, -1.00 +22, right, -1.00 +23, right, -1.00 +24, right, -1.00 +25, right, -1.00 +26, right, -1.00 +27, right, -1.00 +28, right, -1.00 +29, right, -1.00 +30, right, -1.00 +31, right, -1.00 +32, right, -1.00 +33, right, -1.00 +34, right, -1.00 +35, right, -1.00 +36, right, 
+37, right, -1.00
+38, right, -1.00
+39, right, -1.00
+40, right, -1.00
+41, right, -1.00
+42, right, -1.00
+43, right, -1.00
+44, right, -1.00
+45, right, -1.00
+46, right, -1.00
+47, right, -1.00
+48, right, -1.00
+49, right, -1.00
+50, right, -1.00
+51, right, -1.00
+52, right, -1.00
+53, right, -1.00
+54, right, -1.00
+55, right, -1.00
+56, right, -1.00
+57, right, -1.00
+58, right, -1.00
+59, right, -1.00
+60, right, -1.00
+61, right, -1.00
+62, right, -1.00
+63, right, -1.00
+64, right, -1.00
+65, right, -1.00
+66, right, -1.00
+67, right, -1.00
+68, right, -1.00
+69, right, -1.00
+70, right, -1.00
+71, right, -1.00
+72, right, -1.00
+73, right, -1.00
+74, right, -1.00
+75, right, -1.00
+76, right, -1.00
+77, right, -1.00
+78, right, -1.00
+79, right, -1.00
+80, right, -1.00
+81, right, -1.00
+82, right, -1.00
+83, right, -1.00
+84, right, -1.00
+85, right, -1.00
+86, right, -1.00
+87, right, -1.00
+88, right, -1.00
+89, right, -1.00
+90, right, -1.00
+91, right, -1.00
+92, right, -1.00
+93, right, -1.00
+94, right, -1.00
+95, right, -1.00
+96, right, -1.00
+97, right, -1.00
+98, right, -1.00
+99, right, -1.00
+100, right, -1.00
+101, right, -1.00
+102, right, -1.00
+103, right, -1.00
+104, right, -1.00
+105, right, -1.00
+106, right, -1.00
+107, right, -1.00
+108, right, -1.00
+109, right, -1.00
+110, right, -1.00
+111, right, -1.00
+112, right, -1.00
+113, right, -1.00
+114, right, -1.00
+115, right, -1.00
+116, right, -1.00
+117, right, -1.00
+118, right, -1.00
+119, right, -1.00
+120, right, -1.00
+121, right, -1.00
+122, right, -1.00
+123, right, -1.00
+124, right, -1.00
+125, right, -1.00
+126, right, -1.00
+127, right, -1.00
+128, right, -1.00
+129, right, -1.00
+130, right, -1.00
+131, right, -1.00
+132, right, -1.00
+133, right, -1.00
+134, right, -1.00
+135, right, -1.00
+136, right, -1.00
+137, right, -1.00
+138, right, -1.00
+139, right, -1.00
+140, right, -1.00
+141, right, -1.00
+142, right, -1.00
+143, right, -1.00
+144, right, -1.00
+145, right, -1.00
+146, right, -1.00
+147, right, -1.00
+148, right, -1.00
+149, right, -1.00
+150, right, -1.00
+151, right, -1.00
+152, right, -1.00
+153, right, -1.00
+154, right, -1.00
+155, right, -1.00
+156, right, -1.00
+157, right, -1.00
+158, right, -1.00
+159, right, -1.00
+160, right, -1.00
+161, right, -1.00
+162, right, -1.00
+163, right, -1.00
+164, right, -1.00
+165, right, -1.00
+166, right, -1.00
+167, right, -1.00
+168, right, -1.00
+169, right, -1.00
+170, right, -1.00
+171, right, -1.00
+172, right, -1.00
+173, right, -1.00
+174, right, -1.00
+175, right, -1.00
+176, right, -1.00
+177, right, -1.00
+178, right, -1.00
+179, right, -1.00
+180, right, -1.00
+181, right, -1.00
+182, right, -1.00
+183, right, -1.00
+184, right, -1.00
+185, right, -1.00
+186, none, -1.00
+187, none, -1.00
+188, none, -1.00
+189, none, -1.00
+190, none, -1.00
+191, none, -1.00
+192, none, -1.00
+193, none, -1.00
diff --git a/tests/test_data/test_short_result.txt b/tests/test_data/test_short_result.txt
new file mode 100644
index 0000000..eb1fd3e
--- /dev/null
+++ b/tests/test_data/test_short_result.txt
@@ -0,0 +1,194 @@
+0, none, -1.00
+1, none, -1.00
+2, none, -1.00
+3, none, -1.00
+4, left, 0.84
+5, left, 0.92
+6, left, 0.95
+7, left, 0.96
+8, left, 0.96
+9, left, 0.97
+10, left, 0.96
+11, left, 0.98
+12, left, 0.99
+13, left, 0.99
+14, left, 1.00
+15, left, 1.00
+16, left, 1.00
+17, left, 1.00
+18, left, 1.00
+19, left, 1.00
+20, left, 1.00
+21, left, 1.00
+22, left, 1.00
+23, left, 1.00
+24, left, 1.00
+25, left, 1.00
+26, left, 1.00
+27, left, 1.00
+28, left, 1.00
+29, left, 1.00
+30, left, 1.00
+31, left, 1.00
+32, left, 1.00
+33, left, 1.00
+34, left, 1.00
+35, left, 1.00
+36, left, 1.00
+37, left, 1.00
+38, left, 1.00
+39, left, 1.00
+40, left, 1.00
+41, left, 1.00
+42, left, 1.00
+43, left, 1.00
+44, left, 1.00
+45, left, 1.00
+46, left, 1.00
+47, left, 1.00
+48, left, 1.00
+49, left, 1.00
+50, left, 1.00
+51, left, 1.00
+52, left, 1.00
+53, left, 1.00
+54, left, 1.00
+55, left, 1.00
+56, left, 1.00
+57, left, 1.00
+58, left, 1.00
+59, left, 1.00
+60, left, 1.00
+61, left, 1.00
+62, left, 1.00
+63, left, 1.00
+64, left, 0.99
+65, left, 0.99
+66, left, 0.94
+67, left, 0.99
+68, left, 0.98
+69, left, 0.99
+70, left, 0.91
+71, left, 0.96
+72, left, 0.93
+73, left, 0.98
+74, left, 0.98
+75, left, 1.00
+76, left, 1.00
+77, left, 1.00
+78, left, 0.99
+79, left, 1.00
+80, left, 1.00
+81, left, 1.00
+82, left, 1.00
+83, left, 1.00
+84, left, 1.00
+85, left, 1.00
+86, left, 1.00
+87, left, 1.00
+88, left, 1.00
+89, left, 1.00
+90, left, 1.00
+91, left, 1.00
+92, left, 1.00
+93, left, 1.00
+94, left, 1.00
+95, left, 1.00
+96, left, 1.00
+97, left, 1.00
+98, left, 1.00
+99, left, 1.00
+100, left, 1.00
+101, left, 1.00
+102, left, 0.99
+103, left, 1.00
+104, left, 1.00
+105, left, 1.00
+106, left, 1.00
+107, left, 1.00
+108, left, 0.99
+109, left, 1.00
+110, left, 0.99
+111, left, 0.99
+112, left, 0.98
+113, left, 0.98
+114, left, 0.98
+115, left, 0.98
+116, left, 0.97
+117, left, 0.97
+118, left, 0.94
+119, left, 0.95
+120, left, 0.96
+121, left, 0.96
+122, left, 0.97
+123, left, 0.97
+124, left, 0.99
+125, left, 0.98
+126, left, 0.99
+127, left, 0.99
+128, left, 0.99
+129, left, 0.99
+130, left, 1.00
+131, left, 0.99
+132, left, 0.99
+133, left, 0.99
+134, left, 0.98
+135, left, 0.99
+136, left, 0.98
+137, left, 0.99
+138, left, 0.98
+139, left, 0.99
+140, left, 0.97
+141, left, 0.97
+142, left, 0.98
+143, left, 0.98
+144, left, 0.98
+145, left, 0.99
+146, left, 0.99
+147, left, 0.98
+148, left, 0.98
+149, left, 0.98
+150, left, 0.98
+151, left, 0.98
+152, left, 0.98
+153, left, 0.98
+154, left, 0.98
+155, left, 0.99
+156, left, 0.99
+157, left, 0.99
+158, left, 0.99
+159, left, 1.00
+160, left, 1.00
+161, left, 1.00
+162, left, 1.00
+163, left, 1.00
+164, left, 1.00
+165, left, 1.00
+166, left, 1.00
+167, left, 0.99
+168, left, 1.00
+169, left, 0.99
+170, left, 0.99
+171, left, 0.99
+172, left, 0.99
+173, left, 0.99
+174, left, 1.00
+175, left, 0.99
+176, left, 1.00
+177, left, 0.99
+178, left, 1.00
+179, left, 0.99
+180, left, 1.00
+181, left, 1.00
+182, left, 1.00
+183, left, 1.00
+184, left, 1.00
+185, left, 1.00
+186, left, 1.00
+187, left, 1.00
+188, left, 1.00
+189, left, 1.00
+190, none, -1.00
+191, none, -1.00
+192, none, -1.00
+193, none, -1.00
diff --git a/tests/test_gaze_model.py b/tests/test_gaze_model.py
index d3ff857..40895d8 100644
--- a/tests/test_gaze_model.py
+++ b/tests/test_gaze_model.py
@@ -7,13 +7,13 @@
 @pytest.mark.parametrize(
     "args_string, model_class_name",
     [
-        ("tests/test_data/test.mp4", "RegNet"),
+        ("tests/test_data/test_short.mp4", "RegNet"),
         (
-            "tests/test_data/test.mp4 --model icatcher+_lookit_regnet.pth",
+            "tests/test_data/test_short.mp4 --model icatcher+_lookit_regnet.pth",
             "RegNet",
         ),
         (
-            "tests/test_data/test.mp4 --model icatcher+_lookit.pth",
+            "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth",
             "ResNet",
         ),
     ],
 )
@@ -38,8 +38,8 @@ def test_load_models(args_string, model_class_name):
 @pytest.mark.parametrize(
     "args_string",
     [
-        "tests/test_data/test.mp4 --model icatcher+_lookit_regnet.pth --gpu_id=0",
-        "tests/test_data/test.mp4 --model icatcher+_lookit.pth --gpu_id=0",
+        "tests/test_data/test_short.mp4 --model icatcher+_lookit_regnet.pth --gpu_id=0",
+        "tests/test_data/test_short.mp4 --model icatcher+_lookit.pth --gpu_id=0",
     ],
 )
 def test_predict_from_video(args_string):
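
Note on the new window arithmetic in cli.py (an illustrative sketch, not part of the patch): loc and cursor now both point at the middle frame of the sliding window instead of the hard-coded -5, which was only correct for a window of 9 frames. A minimal standalone sketch, assuming a window size of 9 purely for illustration:

    import numpy as np

    sliding_window_size = 9  # must be odd, per the new check in options.py
    cursor = -((sliding_window_size // 2) + 1)  # -5: middle frame, indexed from the end
    loc = cursor  # predictions are written back for the same middle frame

    # reporting starts once enough frames are buffered to reach the cursor;
    # earlier frames keep the "none" label, which matches the leading
    # "none" rows in the expected result files above:
    for frame_count in range(6):
        ready = frame_count + 1 >= np.abs(cursor)
        print(frame_count, "report" if ready else "none (warm-up)")

Under these assumptions, frames 0 through 3 are warm-up and frame 4 is the first frame with a reported prediction, consistent with test_short_result.txt.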