From a2612cdff70fcc282f31a62116d8837ebf551a3f Mon Sep 17 00:00:00 2001 From: Cathy0908 <30484308+Cathy0908@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:20:44 +0800 Subject: [PATCH] add op video_extract_frames_mapper (#507) * add op video_extract_frames_mapper --- configs/config_all.yaml | 5 + data_juicer/ops/base_op.py | 15 +- data_juicer/ops/mapper/__init__.py | 14 +- .../ops/mapper/video_extract_frames_mapper.py | 173 +++++++++++++ data_juicer/utils/constant.py | 1 + data_juicer/utils/mm_utils.py | 82 +++++- docs/Operators.md | 1 + docs/Operators_ZH.md | 1 + .../test_video_extract_frames_mapper.py | 242 ++++++++++++++++++ 9 files changed, 525 insertions(+), 9 deletions(-) create mode 100644 data_juicer/ops/mapper/video_extract_frames_mapper.py create mode 100644 tests/ops/mapper/test_video_extract_frames_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 756cadd81..1003c89af 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -341,6 +341,11 @@ process: horizontal_flip: false # flip frame image horizontally (left to right). vertical_flip: false # flip frame image vertically (top to bottom). mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched + - video_extract_frames_mapper: # extract frames from video files according to specified methods + frame_sampling_method: 'all_keyframes' # sampling method for extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former extracts all key frames and the latter extracts a specified number of frames uniformly from the video. Default: "all_keyframes". + frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + duration: 0 # The duration of each segment in seconds. If 0, frames are extracted from the entire video. If duration > 0, the video is segmented into multiple segments based on duration, and frames are extracted from each segment. + frame_dir: None # Output directory to save extracted frames. If None, a default directory based on the video file path is used. - video_face_blur_mapper: # blur faces detected in videos cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index e7b0c9cc5..72618a6bb 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -259,11 +259,22 @@ def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = next(iter(keys)) num_samples = len(samples[first_key]) + + # collect fields newly added by process_single; they are appended + # to the batch as new columns after the loop + new_keys = {} for i in range(num_samples): this_sample = {key: samples[key][i] for key in keys} res_sample = self.process_single(this_sample, *args, **kwargs) - for key in keys: - samples[key][i] = res_sample[key] + res_keys = res_sample.keys() + for key in res_keys: + if key not in keys: + if key not in new_keys: + new_keys.update({key: []}) + new_keys[key].append(res_sample[key]) + else: + samples[key][i] = res_sample[key] + + # attach the collected fields to the batch as whole columns + for k, v in new_keys.items(): + samples[k] = v return samples
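Reviewer note on the `process_batched` change above: `process_single` may now return fields that were absent from the input batch (the new op adds `Fields.video_frames` this way), and the batched wrapper gathers them into new columns instead of raising a `KeyError`. A minimal sketch with a hypothetical toy mapper (`AddLengthMapper` is illustrative only, not part of this PR, and assumes the base `Mapper` can be constructed with defaults outside a pipeline):

```python
from data_juicer.ops.base_op import Mapper


class AddLengthMapper(Mapper):  # hypothetical toy op for illustration
    def process_single(self, sample, *args, **kwargs):
        # add a field that does not exist in the input batch
        sample['text_len'] = len(sample['text'])
        return sample


batch = {'text': ['hello', 'data-juicer']}
print(AddLengthMapper().process_batched(batch))
# expected: {'text': ['hello', 'data-juicer'], 'text_len': [5, 11]}
```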
diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index 6d2a26949..3c95d1fe6 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -52,6 +52,7 @@ from .video_captioning_from_summarizer_mapper import \ VideoCaptioningFromSummarizerMapper from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper +from .video_extract_frames_mapper import VideoExtractFramesMapper from .video_face_blur_mapper import VideoFaceBlurMapper from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper @@ -84,10 +85,11 @@ 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', - 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', - 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', - 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', - 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', - 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', - 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' + 'VideoCaptioningFromVideoMapper', 'VideoExtractFramesMapper', + 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', + 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', + 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', + 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', + 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', + 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/video_extract_frames_mapper.py b/data_juicer/ops/mapper/video_extract_frames_mapper.py new file mode 100644 index 000000000..4eb522abe --- /dev/null +++ b/data_juicer/ops/mapper/video_extract_frames_mapper.py @@ -0,0 +1,173 @@ +import json +import os +import os.path as osp + +from pydantic import PositiveInt + +from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import dict_to_hash +from data_juicer.utils.mm_utils import ( + SpecialTokens, close_video, extract_key_frames, + extract_key_frames_by_seconds, extract_video_frames_uniformly, + extract_video_frames_uniformly_by_seconds, load_data_with_context, + load_video) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_extract_frames_mapper' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoExtractFramesMapper(Mapper): + """Mapper to extract frames from video files according to specified methods. + Extracted Frames Data Format: + The data format for the extracted frames is a dictionary mapping + each video key to the directory where its extracted + frames are saved. The dictionary follows the structure: + { + "video_key_1": "/${frame_dir}/video_key_1_filename/", + "video_key_2": "/${frame_dir}/video_key_2_filename/", + ... + } + """ + + _batched_op = True + + def __init__( + self, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + duration: float = 0, + frame_dir: str = None, + frame_key=Fields.video_frames, + *args, + **kwargs, + ): + """ + Initialization method. + :param frame_sampling_method: sampling method for extracting frame + images from the videos. Should be one of + ["all_keyframes", "uniform"]. + The former one extracts all key frames (the number + of which depends on the duration of the video) and the latter + one extracts a specified number of frames uniformly from the + video. If "duration" > 0, frame_sampling_method acts on every + segment. Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + If "duration" > 0, frame_num is the number of frames per segment. + :param duration: The duration of each segment in seconds. + If 0, frames are extracted from the entire video. + If duration > 0, the video is segmented into multiple segments + based on duration, and frames are extracted from each segment. + :param frame_dir: Output directory to save extracted frames. + If None, a default directory based on the video file path is used. + :param frame_key: The name of the field used to save the generated + frames info. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. ' + f'Can only be one of ["all_keyframes", "uniform"].')
+ + self.frame_dir = frame_dir + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + self.duration = duration + self.frame_key = frame_key + self.frame_fname_template = 'frame_{}.jpg' + + def _get_default_frame_dir(self, original_filepath): + original_dir = os.path.dirname(original_filepath) + dir_token = f'/{Fields.multimodal_data_output_dir}/' + if dir_token in original_dir: + original_dir = original_dir.split(dir_token)[0] + saved_dir = os.path.join( + original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}') + original_filename = osp.splitext(osp.basename(original_filepath))[0] + hash_val = dict_to_hash(self._init_parameters) + + return osp.join(saved_dir, + f'{original_filename}__dj_hash_#{hash_val}#') + + def process_single(self, sample, context=False): + # check if it's generated already + if self.frame_key in sample: + return sample + + # there are no videos in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + video_to_frame_dir = {} + text = sample[self.text_key] + offset = 0 + + for chunk in text.split(SpecialTokens.eoc): + video_count = chunk.count(SpecialTokens.video) + # no video or no text + if video_count == 0 or len(chunk) == 0: + continue + else: + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + # extract frames from the video + if self.frame_sampling_method == 'all_keyframes': + if self.duration: + frames = extract_key_frames_by_seconds( + video, self.duration) + else: + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + if self.duration: + frames = extract_video_frames_uniformly_by_seconds( + video, self.frame_num, duration=self.duration) + else: + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + raise ValueError(f'Unsupported sampling method ' + f'`{self.frame_sampling_method}`.') + frames = [frame.to_image() for frame in frames] + + if self.frame_dir: + frame_dir = osp.join( + self.frame_dir, + osp.splitext(osp.basename(video_key))[0]) + else: + # derive a default frame directory from the video path + frame_dir = self._get_default_frame_dir(video_key) + os.makedirs(frame_dir, exist_ok=True) + video_to_frame_dir[video_key] = frame_dir + + for i, frame in enumerate(frames): + frame_path = osp.join( + frame_dir, self.frame_fname_template.format(i)) + if not os.path.exists(frame_path): + frame.save(frame_path) + + offset += video_count + + if not context: + for vid_key in videos: + close_video(videos[vid_key]) + + sample[self.frame_key] = json.dumps(video_to_frame_dir) + + return sample
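For reviewers, a minimal sketch of driving the new op on a single sample; the video path and caption are illustrative placeholders, and `process_single` is the per-sample entry point shown above:

```python
import json

from data_juicer.ops.mapper.video_extract_frames_mapper import \
    VideoExtractFramesMapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

# uniform sampling, 3 frames over the whole video (duration=0 disables
# segmentation); frame_dir=None would fall back to the default directory
op = VideoExtractFramesMapper(frame_sampling_method='uniform',
                              frame_num=3,
                              duration=0,
                              frame_dir='./extracted_frames')

sample = {
    'text': f'{SpecialTokens.video} an illustrative caption {SpecialTokens.eoc}',
    'videos': ['video1.mp4'],  # placeholder path
}
result = op.process_single(sample)
# the new field maps each video path to its frame directory
print(json.loads(result[Fields.video_frames]))
```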
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ab88035b9..8bc78ad5a 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -16,6 +16,7 @@ class Fields(object): context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' + video_frames = DEFAULT_PREFIX + 'video_frames__' # video_frame_tags video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 4e9aff894..49e5046ab 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -1,5 +1,6 @@ import base64 import datetime +import io import os import re import shutil @@ -321,7 +322,11 @@ def cut_video_by_seconds( container = input_video # create the output video - output_container = load_video(output_video, 'w') + if output_video: + output_container = load_video(output_video, 'w') + else: + output_buffer = io.BytesIO() + output_container = av.open(output_buffer, mode='w', format='mp4') # add the video stream into the output video according to input video input_video_stream = container.streams.video[0] @@ -390,6 +395,11 @@ def cut_video_by_seconds( if isinstance(input_video, str): close_video(container) close_video(output_container) + + if not output_video: + output_buffer.seek(0) + return output_buffer + if not os.path.exists(output_video): logger.warning(f'This video could not be successfully cut in ' f'[{start_seconds}, {end_seconds}] seconds. ' @@ -463,6 +473,39 @@ def process_each_frame(input_video: Union[str, av.container.InputContainer], if isinstance(input_video, str) else input_video.name) +def extract_key_frames_by_seconds( + input_video: Union[str, av.container.InputContainer], + duration: float = 1): + """Extract key frames from each `duration`-second segment of the video. + :param input_video: input video path or av.container.InputContainer. + :param duration: duration of each video split in seconds. + """ + # load the input video + if isinstance(input_video, str): + container = load_video(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. Should be one of ' + f'[str, av.container.InputContainer], but given ' + f'[{type(input_video)}].') + + video_duration = get_video_duration(container) + timestamps = np.arange(0, video_duration, duration).tolist() + + all_key_frames = [] + for i in range(1, len(timestamps)): + output_buffer = cut_video_by_seconds(container, None, + timestamps[i - 1], timestamps[i]) + if output_buffer: + cut_inp_container = av.open(output_buffer, format='mp4', mode='r') + key_frames = extract_key_frames(cut_inp_container) + all_key_frames.extend(key_frames) + close_video(cut_inp_container) + + return all_key_frames + + def extract_key_frames(input_video: Union[str, av.container.InputContainer]): """ Extract key frames from the input video. If there is no keyframes in the @@ -516,6 +559,43 @@ def get_key_frame_seconds(input_video: Union[str, return ts +def extract_video_frames_uniformly_by_seconds( + input_video: Union[str, av.container.InputContainer], + frame_num: PositiveInt, + duration: float = 1): + """Extract `frame_num` frames uniformly from each `duration`-second + segment of the video. + :param input_video: input video path or av.container.InputContainer. + :param frame_num: the number of frames to be extracted uniformly from + each video split by duration. + :param duration: duration of each video split in seconds. + """ + # load the input video + if isinstance(input_video, str): + container = load_video(input_video) + elif isinstance(input_video, av.container.InputContainer): + container = input_video + else: + raise ValueError(f'Unsupported type of input_video. 
Should be one of ' f'[str, av.container.InputContainer], but given ' f'[{type(input_video)}].') + + video_duration = get_video_duration(container) + timestamps = np.arange(0, video_duration, duration).tolist() + + all_frames = [] + for i in range(1, len(timestamps)): + output_buffer = cut_video_by_seconds(container, None, + timestamps[i - 1], timestamps[i]) + if output_buffer: + cut_inp_container = av.open(output_buffer, format='mp4', mode='r') + frames = extract_video_frames_uniformly(cut_inp_container, + frame_num=frame_num) + all_frames.extend(frames) + close_video(cut_inp_container) + + return all_frames + + def extract_video_frames_uniformly( input_video: Union[str, av.container.InputContainer], frame_num: PositiveInt,
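A quick usage sketch of the two mm_utils helpers above may help reviewers; it assumes a local `video.mp4` at least a few seconds long (the path and the 2-second duration are illustrative), and reloads the container between calls rather than relying on seek state:

```python
from data_juicer.utils.mm_utils import (
    close_video, extract_key_frames_by_seconds,
    extract_video_frames_uniformly_by_seconds, load_video)

video_path = 'video.mp4'  # placeholder path

# all key frames, collected per 2-second segment
container = load_video(video_path)
key_frames = extract_key_frames_by_seconds(container, duration=2)
close_video(container)

# 3 uniformly spaced frames from each 2-second segment
container = load_video(video_path)
uniform_frames = extract_video_frames_uniformly_by_seconds(
    container, 3, duration=2)
close_video(container)

print(len(key_frames), len(uniform_frames))
```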
diff --git a/docs/Operators.md b/docs/Operators.md index 1d6b63ec5..04a2da380 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -106,6 +106,7 @@ All the specific operators are listed below, each featured with several capabili | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) | [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | | video_captioning_from_video_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample | [code](../data_juicer/ops/mapper/video_captioning_from_video_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_video_mapper.py) | +| video_extract_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) | extract frames from video files according to specified methods | [code](../data_juicer/ops/mapper/video_extract_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_extract_frames_mapper.py) | | video_face_blur_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Blur faces detected in videos | [code](../data_juicer/ops/mapper/video_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_video_face_blur_mapper.py) | | video_ffmpeg_wrapped_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Simple wrapper to run an FFmpeg video filter | [code](../data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py) | | video_remove_watermark_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Remove the watermarks in videos given regions | [code](../data_juicer/ops/mapper/video_remove_watermark_mapper.py) | [tests](../tests/ops/mapper/test_video_remove_watermark_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 01f0bdb0a..011ff5a64 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -105,6 +105,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | | video_captioning_from_summarizer_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...)
| [code](../data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py) | | video_captioning_from_video_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | [code](../data_juicer/ops/mapper/video_captioning_from_video_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_video_mapper.py) | +| video_extract_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) | 按指定方法从视频文件中抽取帧图像。 | [code](../data_juicer/ops/mapper/video_extract_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_extract_frames_mapper.py) | | video_face_blur_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 对视频中的人脸进行模糊处理 | [code](../data_juicer/ops/mapper/video_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_video_face_blur_mapper.py) | | video_ffmpeg_wrapped_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 运行 FFmpeg 视频过滤器的简单封装 | [code](../data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py) | [tests](../tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py) | | video_remove_watermark_mapper | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | 去除视频中给定区域的水印 | [code](../data_juicer/ops/mapper/video_remove_watermark_mapper.py) | [tests](../tests/ops/mapper/test_video_remove_watermark_mapper.py) | diff --git a/tests/ops/mapper/test_video_extract_frames_mapper.py b/tests/ops/mapper/test_video_extract_frames_mapper.py new file mode 100644 index 000000000..7ae2dd29f --- /dev/null +++ b/tests/ops/mapper/test_video_extract_frames_mapper.py @@ -0,0 +1,242 @@ +import os +import os.path as osp +import re +import copy +import unittest +import json +import tempfile +import shutil +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.video_extract_frames_mapper import \ VideoExtractFramesMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoExtractFramesMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + # mkdtemp keeps the scratch dir alive until tearDown removes it + tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + super().tearDown() + if osp.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) + + default_frame_dir_prefix = self._get_default_frame_dir_prefix() + if osp.exists(default_frame_dir_prefix): + shutil.rmtree(osp.dirname(default_frame_dir_prefix)) + + def _get_default_frame_dir_prefix(self): + from data_juicer.ops.mapper.video_extract_frames_mapper import OP_NAME + default_frame_dir_prefix = osp.abspath(osp.join(self.data_path, + f'{Fields.multimodal_data_output_dir}/{OP_NAME}/')) + return default_frame_dir_prefix + + def _get_frames_list(self, filepath, frame_dir, frame_num): + frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0]) + frames_list = [osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)] + return frames_list + + def _get_frames_dir(self, filepath, frame_dir): + frames_dir = osp.join(frame_dir,
osp.splitext(osp.basename(filepath))[0]) + return frames_dir + + def _sort_files(self, file_list): + return sorted(file_list, key=lambda x: int(re.search(r'(\d+)', x).group())) + + def test_duration(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + + frame_num = 2 + frame_dir=os.path.join(self.tmp_dir, 'test1') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=0, + frame_dir=frame_dir) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(frame_num)]) + + def test_uniform_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_num = 3 + frame_dir=os.path.join(self.tmp_dir, 'test1') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=10, + frame_dir=frame_dir) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(3)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(6)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(12)]) + + def test_all_keyframes_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': 
[self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}' + \ + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_dir=os.path.join(self.tmp_dir, 'test2') + vid1_frame_dir = self._get_frames_dir(self.vid1_path, frame_dir) + vid2_frame_dir = self._get_frames_dir(self.vid2_path, frame_dir) + vid3_frame_dir = self._get_frames_dir(self.vid3_path, frame_dir) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: + json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({ + self.vid2_path: vid2_frame_dir, + self.vid3_path: vid3_frame_dir + })}) + tgt_list[2].update({Fields.video_frames: + json.dumps({self.vid3_path: vid3_frame_dir})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='all_keyframes', + frame_dir=frame_dir, + duration=5) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=2) + res_list = dataset.to_list() + self.assertEqual(res_list, tgt_list) + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(4)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(5)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(13)]) + + def test_default_frame_dir(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + + frame_num = 2 + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + duration=5, + ) + + vid1_frame_dir = op._get_default_frame_dir(self.vid1_path) + vid2_frame_dir = op._get_default_frame_dir(self.vid2_path) + vid3_frame_dir = op._get_default_frame_dir(self.vid3_path) + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: json.dumps({self.vid1_path: vid1_frame_dir})}) + tgt_list[1].update({Fields.video_frames: json.dumps({self.vid2_path: vid2_frame_dir})}) + tgt_list[2].update({Fields.video_frames: json.dumps({self.vid3_path: vid3_frame_dir})}) + + dataset = Dataset.from_list(ds_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=1) + res_list = dataset.to_list() + + frame_dir_prefix = self._get_default_frame_dir_prefix() + self.assertIn(frame_dir_prefix, osp.abspath(vid1_frame_dir)) + self.assertIn(frame_dir_prefix, osp.abspath(vid2_frame_dir)) + self.assertIn(frame_dir_prefix, osp.abspath(vid3_frame_dir)) + + self.assertEqual(res_list, tgt_list) + + self.assertListEqual( + self._sort_files(os.listdir(vid1_frame_dir)), + [f'frame_{i}.jpg' for i in range(4)]) + self.assertListEqual( + self._sort_files(os.listdir(vid2_frame_dir)), + [f'frame_{i}.jpg' for i in range(8)]) + self.assertListEqual( + self._sort_files(os.listdir(vid3_frame_dir)), + [f'frame_{i}.jpg' for i in range(18)]) + + +if __name__ == '__main__': + unittest.main()
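A closing note on the frame counts asserted in these tests: both `*_by_seconds` helpers iterate over consecutive pairs of `np.arange(0, video_duration, duration)`, so the trailing partial segment after the last boundary contributes no frames. A back-of-the-envelope sketch follows; the durations of `video1/2/3.mp4` are not stated in the patch, and the sample value below is an assumption for illustration:

```python
import numpy as np


def expected_uniform_frame_count(video_seconds: float, duration: float,
                                 frame_num: int) -> int:
    # mirrors the segment loop in extract_video_frames_uniformly_by_seconds
    timestamps = np.arange(0, video_seconds, duration).tolist()
    num_full_segments = max(len(timestamps) - 1, 0)
    return num_full_segments * frame_num


# With duration=10 and frame_num=3 (as in test_uniform_sampling), a video
# longer than 10 s and at most 20 s yields one full segment => 3 frames;
# the 6- and 12-frame expectations imply 2 and 4 full segments.
print(expected_uniform_frame_count(11.75, 10, 3))  # -> 3
```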