From e8dbba9fbc96eba37a6aab66e750f78fbc014e31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ce=20Ge=20=28=E6=88=88=E7=AD=96=29?= Date: Fri, 26 Jul 2024 17:48:33 +0800 Subject: [PATCH] feat: Enhance video_motion_score_filter capabilities (#361) - Support float-based sampling FPS for finer frame rate control - Resize frames before optical flow for better efficiency - Normalize optical flow magnitude relative to frame dimensions - Improve robustness in frame reading and resizing - Expand tests to cover new features --- configs/config_all.yaml | 3 + .../ops/filter/video_motion_score_filter.py | 97 +++++++++++++++---- data_juicer/utils/unittest_utils.py | 4 +- .../filter/test_video_motion_score_filter.py | 56 +++++++++-- 4 files changed, 134 insertions(+), 26 deletions(-) diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 41a0809a6..7a66292ce 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -406,6 +406,9 @@ process: min_score: 0.25 # the minimum motion score to keep samples max_score: 10000.0 # the maximum motion score to keep samples sampling_fps: 2 # the samplig rate of frames_per_second to compute optical flow + size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w) + max_size: null # maximum allowed for the longer edge of resized frames + relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length any_or_all: any # keep this sample when any/all videos meet the filter condition - video_nsfw_filter: # filter samples according to the nsfw scores of videos in them hf_nsfw_model: Falconsai/nsfw_image_detection # Huggingface model name for nsfw classification diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py index 91e9e03ac..76ddaf4fe 100644 --- a/data_juicer/ops/filter/video_motion_score_filter.py +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -1,8 +1,9 @@ import sys from contextlib import contextmanager +from typing import List, Optional, Sequence, Tuple, Union import numpy as np -from jsonargparse.typing import PositiveInt +from jsonargparse.typing import PositiveFloat, PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -44,7 +45,11 @@ class VideoMotionScoreFilter(Filter): def __init__(self, min_score: float = 0.25, max_score: float = sys.float_info.max, - sampling_fps: PositiveInt = 2, + sampling_fps: PositiveFloat = 2, + size: Optional[Union[PositiveInt, + Sequence[PositiveInt]]] = None, + max_size: Optional[PositiveInt] = None, + relative: bool = False, any_or_all: str = 'any', *args, **kwargs): @@ -53,8 +58,20 @@ def __init__(self, :param min_score: The minimum motion score to keep samples. :param max_score: The maximum motion score to keep samples. - :param sampling_fps: The samplig rate of frames_per_second to - compute optical flow. + :param sampling_fps: The sampling rate in frames_per_second for + optical flow calculations. + :param size: Resize frames before computing optical flow. If size is a + sequence like (h, w), frame size will be matched to this. If size + is an int, smaller edge of frames will be matched to this number. + i.e, if height > width, then frame will be rescaled to (size * + height / width, size). Default `None` to keep the original size. + :param max_size: The maximum allowed for the longer edge of resized + frames. If the longer edge of frames is greater than max_size after + being resized according to size, size will be overruled so that the + longer edge is equal to max_size. As a result, the smaller edge may + be shorter than size. This is only supported if size is an int. + :param relative: If `True`, the optical flow magnitude is normalized to + a [0, 1] range, relative to the frame's diagonal length. :param any_or_all: keep this sample with 'any' or 'all' strategy of all videos. 'any': keep this sample if any videos meet the condition. 'all': keep this sample only if all videos meet the @@ -67,6 +84,17 @@ def __init__(self, self.max_score = max_score self.sampling_fps = sampling_fps + if isinstance(size, (list, tuple)): + if len(size) not in [1, 2]: + raise ValueError( + f'Size must be an int or a 1 or 2 element tuple/list,' + f'not a {len(size)} element tuple/list.') + if isinstance(size, int): + size = [size] + self.size = size + self.max_size = max_size + self.relative = relative + self.extra_kwargs = self._default_kwargs for key in kwargs: if key in self.extra_kwargs: @@ -98,27 +126,31 @@ def compute_stats(self, sample, context=False): video_motion_scores = [] with VideoCapture(video_key) as cap: - fps = cap.get(cv2.CAP_PROP_FPS) - valid_fps = min(max(self.sampling_fps, 1), fps) - frame_interval = int(fps / valid_fps) - total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - # if cannot get the second frame, use the last one - frame_interval = min(frame_interval, total_frames - 1) + if cap.isOpened(): + fps = cap.get(cv2.CAP_PROP_FPS) + sampling_fps = min(self.sampling_fps, fps) + sampling_step = round(fps / sampling_fps) + total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + # at least two frames for computing optical flow + sampling_step = max(min(sampling_step, total_frames - 1), + 1) prev_frame = None - frame_count = -1 + frame_count = 0 while cap.isOpened(): ret, frame = cap.read() - frame_count += 1 - if not ret: # If the frame can't be read, it could be due to # a corrupt frame or reaching the end of the video. break - # skip middle frames - if frame_count % frame_interval != 0: - continue + height, width, _ = frame.shape + new_size = _compute_resized_output_size( + (height, width), self.size, self.max_size) + if new_size != (height, width): + frame = cv2.resize(frame, + new_size, + interpolation=cv2.INTER_AREA) gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) if prev_frame is None: @@ -129,14 +161,21 @@ def compute_stats(self, sample, context=False): prev_frame, gray_frame, None, **self.extra_kwargs) mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1]) frame_motion_score = np.mean(mag) + if self.relative: + frame_motion_score /= np.hypot(*flow.shape[:2]) video_motion_scores.append(frame_motion_score) prev_frame = gray_frame + # quickly skip frames + frame_count += sampling_step + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_count) + # may due to frame corruption if not video_motion_scores: unique_motion_scores[video_key] = -1 else: - unique_motion_scores[video_key] = np.mean(video_motion_scores) + unique_motion_scores[video_key] = np.mean(video_motion_scores + or [-1]) sample[Fields.stats][StatsKeys.video_motion_score] = [ unique_motion_scores[key] for key in loaded_video_keys @@ -159,3 +198,27 @@ def process(self, sample): return keep_bools.any() else: return keep_bools.all() + + +def _compute_resized_output_size( + frame_size: Tuple[int, int], + size: Optional[List[int]], + max_size: Optional[int] = None, +) -> List[int]: + h, w = frame_size + short, long = (w, h) if w <= h else (h, w) + + if size is None: # no change + new_short, new_long = short, long + elif len(size) == 1: # specified size only for the smallest edge + new_short = size[0] + new_long = int(new_short * long / short) + else: # specified both h and w + new_short, new_long = min(size), max(size) + + if max_size is not None and new_long > max_size: + new_short = int(max_size * new_short / new_long) + new_long = max_size + + new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) + return new_h, new_w diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index f382fa688..604bee72d 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -33,9 +33,9 @@ def setUpClass(cls): max_diff = os.getenv('TEST_MAX_DIFF', 'None') cls.maxDiff = None if max_diff == 'None' else int(max_diff) + import multiprocess + cls.original_mp_method = multiprocess.get_start_method() if is_cuda_available(): - import multiprocess - cls.original_mp_method = multiprocess.get_start_method() multiprocess.set_start_method('spawn', force=True) @classmethod diff --git a/tests/ops/filter/test_video_motion_score_filter.py b/tests/ops/filter/test_video_motion_score_filter.py index d8c8367e4..c83ce2fd0 100644 --- a/tests/ops/filter/test_video_motion_score_filter.py +++ b/tests/ops/filter/test_video_motion_score_filter.py @@ -1,7 +1,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset as Dataset +from datasets import Dataset from data_juicer.ops.filter.video_motion_score_filter import \ VideoMotionScoreFilter @@ -29,7 +29,6 @@ def _run_helper(self, op, source_list, target_list, np=1): self.assertEqual(res_list, target_list) def test_default(self): - ds_list = [{ 'videos': [self.vid1_path] }, { @@ -47,8 +46,55 @@ def test_default(self): op = VideoMotionScoreFilter() self._run_helper(op, ds_list, tgt_list) - def test_high(self): + def test_downscale(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreFilter(min_score=1.0, size=120) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_max(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreFilter(min_score=1.0, size=120, max_size=160) + self._run_helper(op, ds_list, tgt_list) + + def test_downscale_relative(self): + ds_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }, { + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'videos': [self.vid1_path] + }, { + 'videos': [self.vid2_path] + }] + op = VideoMotionScoreFilter(min_score=0.005, size=(120, 160), relative=True) + self._run_helper(op, ds_list, tgt_list) + def test_high(self): ds_list = [{ 'videos': [self.vid1_path] }, { @@ -61,7 +107,6 @@ def test_high(self): self._run_helper(op, ds_list, tgt_list) def test_low(self): - ds_list = [{ 'videos': [self.vid1_path] }, { @@ -74,7 +119,6 @@ def test_low(self): self._run_helper(op, ds_list, tgt_list) def test_middle(self): - ds_list = [{ 'videos': [self.vid1_path] }, { @@ -87,7 +131,6 @@ def test_middle(self): self._run_helper(op, ds_list, tgt_list) def test_any(self): - ds_list = [{ 'videos': [self.vid1_path, self.vid2_path] }, { @@ -106,7 +149,6 @@ def test_any(self): self._run_helper(op, ds_list, tgt_list) def test_all(self): - ds_list = [{ 'videos': [self.vid1_path, self.vid2_path] }, {