From 5191cf96f6b90274a2c39393dc91952474d04e48 Mon Sep 17 00:00:00 2001 From: "jiangnana.jnn" Date: Tue, 10 Dec 2024 12:44:39 +0800 Subject: [PATCH] add op video_extract_frames_mapper --- data_juicer/ops/base_op.py | 15 +- data_juicer/ops/mapper/__init__.py | 13 +- .../ops/mapper/video_extract_frames_mapper.py | 163 ++++++++++++++++++ data_juicer/utils/constant.py | 1 + .../test_video_extract_frames_mapper.py | 108 ++++++++++++ 5 files changed, 292 insertions(+), 8 deletions(-) create mode 100644 data_juicer/ops/mapper/video_extract_frames_mapper.py create mode 100644 tests/ops/mapper/test_video_extract_frames_mapper.py diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index 13f3b61ae..1bfc87fca 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -257,11 +257,22 @@ def process_batched(self, samples, *args, **kwargs): keys = samples.keys() first_key = next(iter(keys)) num_samples = len(samples[first_key]) + + new_keys = {} for i in range(num_samples): this_sample = {key: samples[key][i] for key in keys} res_sample = self.process_single(this_sample, *args, **kwargs) - for key in keys: - samples[key][i] = res_sample[key] + res_keys = res_sample.keys() + for key in res_keys: + if key not in keys: + if key not in new_keys: + new_keys.update({key: []}) + new_keys[key].append(res_sample[key]) + else: + samples[key][i] = res_sample[key] + + for k, v in new_keys.items(): + samples[k] = v return samples diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index db4f54e10..e63ba50d9 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -50,6 +50,7 @@ from .video_captioning_from_summarizer_mapper import \ VideoCaptioningFromSummarizerMapper from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper +from .video_extract_frames_mapper import VideoExtractFramesMapper from .video_face_blur_mapper import VideoFaceBlurMapper from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper @@ -82,10 +83,10 @@ 'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper', 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', - 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', - 'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper', - 'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper', - 'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper', - 'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper', - 'WhitespaceNormalizationMapper' + 'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper', + 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', + 'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper', + 'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper', + 'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper', + 'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper' ] diff --git a/data_juicer/ops/mapper/video_extract_frames_mapper.py b/data_juicer/ops/mapper/video_extract_frames_mapper.py new file mode 100644 index 000000000..b1d56d7c6 --- /dev/null +++ b/data_juicer/ops/mapper/video_extract_frames_mapper.py @@ -0,0 +1,163 @@ +import json +import os +import os.path as osp + +from pydantic import PositiveInt + +from data_juicer.utils.constant import Fields +from data_juicer.utils.file_utils import create_directory_if_not_exists +from data_juicer.utils.mm_utils import (SpecialTokens, close_video, + extract_key_frames, + extract_video_frames_uniformly, + load_data_with_context, load_video) + +from ..base_op import OPERATORS, Mapper +from ..op_fusion import LOADED_VIDEOS + +OP_NAME = 'video_extract_frames_mapper' + + +@OPERATORS.register_module(OP_NAME) +@LOADED_VIDEOS.register_module(OP_NAME) +class VideoExtractFramesMapper(Mapper): + """Mapper to extract frames from video files according to specified methods. + Extracted Frames Data Format: + The data format for the extracted frames is a dictionary mapping + video keys to lists of file paths where the extracted frames are saved. + The dictionary follows the structure: + { + "video_key_1": [ + "/${frame_dir}/video_key_1_filename/frame_1.jpg", + "/${frame_dir}/video_key_1_filename/frame_2.jpg", + ...], + "video_key_2": [ + "/${frame_dir}/video_key_2_filename/frame_1.jpg", + "/${frame_dir}/video_key_2_filename/frame_2.jpg", + ...], + ... + } + """ + + _batched_op = True + + def __init__( + self, + frame_sampling_method: str = 'all_keyframes', + frame_num: PositiveInt = 3, + frame_dir: str = None, + frame_key=Fields.video_frames, + *args, + **kwargs, + ): + """ + Initialization method. + :param frame_sampling_method: sampling method of extracting frame + videos from the videos. Should be one of + ["all_keyframes", "uniform"]. + The former one extracts all key frames (the number + of which depends on the duration of the video) and the latter + one extract specified number of frames uniformly from the video. + Default: "all_keyframes". + :param frame_num: the number of frames to be extracted uniformly from + the video. Only works when frame_sampling_method is "uniform". If + it's 1, only the middle frame will be extracted. If it's 2, only + the first and the last frames will be extracted. If it's larger + than 2, in addition to the first and the last frames, other frames + will be extracted uniformly within the video duration. + :param frame_dir: Output directory to save extracted frames. + If None, a default directory based on the video file path is used. + :param frame_key: The name of field to save generated frames info. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self._init_parameters = self.remove_extra_parameters(locals()) + + if frame_sampling_method not in ['all_keyframes', 'uniform']: + raise ValueError( + f'Frame sampling method ' + f'[{frame_sampling_method}] is not supported. ' + f'Can only be one of ["all_keyframes", "uniform"].') + + self.frame_dir = frame_dir + self.frame_sampling_method = frame_sampling_method + self.frame_num = frame_num + self.frame_key = frame_key + self.frame_fname_template = 'frame_{}.jpg' + + def _get_default_frame_dir(self, original_filepath): + original_dir = os.path.dirname(original_filepath) + dir_token = f'/{Fields.multimodal_data_output_dir}/' + if dir_token in original_dir: + original_dir = original_dir.split(dir_token)[0] + new_dir = os.path.join( + original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}') + create_directory_if_not_exists(new_dir) + return osp.join(new_dir, + osp.splitext(osp.basename(original_filepath))[0]) + + def process_single(self, sample, context=False): + # check if it's generated already + if self.frame_key in sample: + return sample + + # there is no videos in this sample + if self.video_key not in sample or not sample[self.video_key]: + return [] + + # load videos + loaded_video_keys = sample[self.video_key] + sample, videos = load_data_with_context(sample, context, + loaded_video_keys, load_video) + video_to_frames = {} + text = sample[self.text_key] + offset = 0 + + for chunk in text.split(SpecialTokens.eoc): + video_count = chunk.count(SpecialTokens.video) + # no video or no text + if video_count == 0 or len(chunk) == 0: + continue + else: + for video_key in loaded_video_keys[offset:offset + + video_count]: + video = videos[video_key] + # extract frame videos + if self.frame_sampling_method == 'all_keyframes': + frames = extract_key_frames(video) + elif self.frame_sampling_method == 'uniform': + frames = extract_video_frames_uniformly( + video, self.frame_num) + else: + raise ValueError(f'Not support sampling method \ + `{self.frame_sampling_method}`.') + frames = [frame.to_image() for frame in frames] + + if self.frame_dir: + frame_dir = osp.join( + self.frame_dir, + osp.splitext(osp.basename(video_key))[0]) + else: + # video path as frames directory + frame_dir = self._get_default_frame_dir(video_key) + os.makedirs(frame_dir, exist_ok=True) + + video_to_frames[video_key] = [] + for i, frame in enumerate(frames): + frame_path = osp.join( + frame_dir, self.frame_fname_template.format(i)) + if not os.path.exists(frame_path): + frame.save(frame_path) + + video_to_frames[video_key].append(frame_path) + + offset += video_count + + if not context: + for vid_key in videos: + close_video(videos[vid_key]) + + sample[self.frame_key] = json.dumps(video_to_frames) + # sample[self.frame_key] = video_to_frames + + return sample diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index ab88035b9..8bc78ad5a 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -16,6 +16,7 @@ class Fields(object): context = DEFAULT_PREFIX + 'context__' suffix = DEFAULT_PREFIX + 'suffix__' + video_frames = DEFAULT_PREFIX + 'video_frames__' # video_frame_tags video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' diff --git a/tests/ops/mapper/test_video_extract_frames_mapper.py b/tests/ops/mapper/test_video_extract_frames_mapper.py new file mode 100644 index 000000000..63f5ffbc3 --- /dev/null +++ b/tests/ops/mapper/test_video_extract_frames_mapper.py @@ -0,0 +1,108 @@ +import os +import os.path as osp +import copy +import unittest +import json +import tempfile +import shutil +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.video_extract_frames_mapper import \ + VideoExtractFramesMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import SpecialTokens +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class VideoExtractFramesMapperTest(DataJuicerTestCaseBase): + + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') + vid1_path = os.path.join(data_path, 'video1.mp4') + vid2_path = os.path.join(data_path, 'video2.mp4') + vid3_path = os.path.join(data_path, 'video3.mp4') + tmp_dir = tempfile.TemporaryDirectory().name + + def tearDown(self): + super().tearDown() + shutil.rmtree(self.tmp_dir) + + def _run_video_extract_frames_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = Dataset.from_list(source_list) + dataset = dataset.map(op.process, batch_size=2, num_proc=num_proc) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def _get_frames_list(self, filepath, frame_dir, frame_num): + frames_dir = osp.join(frame_dir, osp.splitext(osp.basename(filepath))[0]) + frames_list = [osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)] + return frames_list + + def test_uniform_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_num = 3 + frame_dir=os.path.join(self.tmp_dir, 'test1') + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: + json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, frame_num)})}) + tgt_list[1].update({Fields.video_frames: + json.dumps({self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, frame_num)})}) + tgt_list[2].update({Fields.video_frames: + json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, frame_num)})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='uniform', + frame_num=frame_num, + frame_dir=frame_dir) + self._run_video_extract_frames_mapper(op, ds_list, tgt_list) + + + def test_all_keyframes_sampling(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}' + \ + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid2_path, self.vid3_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + frame_dir=os.path.join(self.tmp_dir, 'test2') + + tgt_list = copy.deepcopy(ds_list) + tgt_list[0].update({Fields.video_frames: + json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, 3)})}) + tgt_list[1].update({Fields.video_frames: json.dumps({ + self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, 3), + self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6) + })}) + tgt_list[2].update({Fields.video_frames: + json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6)})}) + + op = VideoExtractFramesMapper( + frame_sampling_method='all_keyframes', + frame_dir=frame_dir) + self._run_video_extract_frames_mapper(op, ds_list, tgt_list) + + + +if __name__ == '__main__': + unittest.main()