add op video_extract_frames_mapper
Cathy0908 committed Dec 10, 2024
1 parent 4ab426e commit 5191cf9
Showing 5 changed files with 292 additions and 8 deletions.
15 changes: 13 additions & 2 deletions data_juicer/ops/base_op.py
@@ -257,11 +257,22 @@ def process_batched(self, samples, *args, **kwargs):
         keys = samples.keys()
         first_key = next(iter(keys))
         num_samples = len(samples[first_key])
+
+        new_keys = {}
         for i in range(num_samples):
             this_sample = {key: samples[key][i] for key in keys}
             res_sample = self.process_single(this_sample, *args, **kwargs)
-            for key in keys:
-                samples[key][i] = res_sample[key]
+            res_keys = res_sample.keys()
+            for key in res_keys:
+                if key not in keys:
+                    if key not in new_keys:
+                        new_keys.update({key: []})
+                    new_keys[key].append(res_sample[key])
+                else:
+                    samples[key][i] = res_sample[key]
+
+        for k, v in new_keys.items():
+            samples[k] = v

         return samples
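In effect, process_single can now return keys that were not present in the batched samples; each new key is collected per sample and attached as a full column once the loop finishes. A minimal, self-contained sketch of that behavior (the sample data and fake_process_single are illustrative stand-ins, not code from this commit):

def fake_process_single(sample):
    # pretend the op adds a brand-new field next to the existing ones
    sample = dict(sample)
    sample['frames'] = 'frames_of_' + sample['videos'][0]
    return sample

samples = {'text': ['a', 'b'], 'videos': [['v1.mp4'], ['v2.mp4']]}
keys = samples.keys()

new_keys = {}
for i in range(len(samples['text'])):
    this_sample = {key: samples[key][i] for key in keys}
    res_sample = fake_process_single(this_sample)
    for key in res_sample.keys():
        if key not in keys:
            # a key the batch has never seen: accumulate it separately
            new_keys.setdefault(key, []).append(res_sample[key])
        else:
            samples[key][i] = res_sample[key]
for k, v in new_keys.items():
    samples[k] = v

# samples now carries a 'frames' column aligned with the batch:
# {'text': ['a', 'b'],
#  'videos': [['v1.mp4'], ['v2.mp4']],
#  'frames': ['frames_of_v1.mp4', 'frames_of_v2.mp4']}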

13 changes: 7 additions & 6 deletions data_juicer/ops/mapper/__init__.py
@@ -50,6 +50,7 @@
 from .video_captioning_from_summarizer_mapper import \
     VideoCaptioningFromSummarizerMapper
 from .video_captioning_from_video_mapper import VideoCaptioningFromVideoMapper
+from .video_extract_frames_mapper import VideoExtractFramesMapper
 from .video_face_blur_mapper import VideoFaceBlurMapper
 from .video_ffmpeg_wrapped_mapper import VideoFFmpegWrappedMapper
 from .video_remove_watermark_mapper import VideoRemoveWatermarkMapper
@@ -82,10 +83,10 @@
     'ReplaceContentMapper', 'SentenceSplitMapper', 'TextChunkMapper',
     'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper',
     'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper',
-    'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper',
-    'VideoRemoveWatermarkMapper', 'VideoResizeAspectRatioMapper',
-    'VideoResizeResolutionMapper', 'VideoSplitByDurationMapper',
-    'VideoSplitByKeyFrameMapper', 'VideoSplitBySceneMapper',
-    'VideoTaggingFromAudioMapper', 'VideoTaggingFromFramesMapper',
-    'WhitespaceNormalizationMapper'
+    'VideoExtractFramesMapper', 'VideoFFmpegWrappedMapper',
+    'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper',
+    'VideoResizeAspectRatioMapper', 'VideoResizeResolutionMapper',
+    'VideoSplitByDurationMapper', 'VideoSplitByKeyFrameMapper',
+    'VideoSplitBySceneMapper', 'VideoTaggingFromAudioMapper',
+    'VideoTaggingFromFramesMapper', 'WhitespaceNormalizationMapper'
 ]
163 changes: 163 additions & 0 deletions data_juicer/ops/mapper/video_extract_frames_mapper.py
@@ -0,0 +1,163 @@
import json
import os
import os.path as osp

from pydantic import PositiveInt

from data_juicer.utils.constant import Fields
from data_juicer.utils.file_utils import create_directory_if_not_exists
from data_juicer.utils.mm_utils import (SpecialTokens, close_video,
                                        extract_key_frames,
                                        extract_video_frames_uniformly,
                                        load_data_with_context, load_video)

from ..base_op import OPERATORS, Mapper
from ..op_fusion import LOADED_VIDEOS

OP_NAME = 'video_extract_frames_mapper'


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoExtractFramesMapper(Mapper):
    """Mapper to extract frames from video files according to specified
    methods.

    Extracted frames data format:
        The extracted frames are recorded as a dictionary mapping each
        video key to the list of file paths where its frames are saved:
        {
            "video_key_1": [
                "/${frame_dir}/video_key_1_filename/frame_1.jpg",
                "/${frame_dir}/video_key_1_filename/frame_2.jpg",
                ...],
            "video_key_2": [
                "/${frame_dir}/video_key_2_filename/frame_1.jpg",
                "/${frame_dir}/video_key_2_filename/frame_2.jpg",
                ...],
            ...
        }
    """

    _batched_op = True

    def __init__(
        self,
        frame_sampling_method: str = 'all_keyframes',
        frame_num: PositiveInt = 3,
        frame_dir: str = None,
        frame_key=Fields.video_frames,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param frame_sampling_method: sampling method for extracting frames
            from the videos. Should be one of ["all_keyframes", "uniform"].
            The former extracts all key frames (whose number depends on the
            duration of the video) and the latter extracts a specified
            number of frames uniformly from the video.
            Default: "all_keyframes".
        :param frame_num: the number of frames to be extracted uniformly
            from the video. Only works when frame_sampling_method is
            "uniform". If it's 1, only the middle frame will be extracted.
            If it's 2, only the first and the last frames will be
            extracted. If it's larger than 2, in addition to the first and
            the last frames, other frames will be extracted uniformly
            within the video duration.
        :param frame_dir: output directory to save extracted frames.
            If None, a default directory based on the video file path is
            used.
        :param frame_key: the name of the field to save generated frames
            info.
        :param args: extra args
        :param kwargs: extra args
        """
        super().__init__(*args, **kwargs)
        self._init_parameters = self.remove_extra_parameters(locals())

        if frame_sampling_method not in ['all_keyframes', 'uniform']:
            raise ValueError(
                f'Frame sampling method '
                f'[{frame_sampling_method}] is not supported. '
                f'Can only be one of ["all_keyframes", "uniform"].')

        self.frame_dir = frame_dir
        self.frame_sampling_method = frame_sampling_method
        self.frame_num = frame_num
        self.frame_key = frame_key
        self.frame_fname_template = 'frame_{}.jpg'

    def _get_default_frame_dir(self, original_filepath):
        original_dir = os.path.dirname(original_filepath)
        dir_token = f'/{Fields.multimodal_data_output_dir}/'
        if dir_token in original_dir:
            original_dir = original_dir.split(dir_token)[0]
        new_dir = os.path.join(
            original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}')
        create_directory_if_not_exists(new_dir)
        return osp.join(new_dir,
                        osp.splitext(osp.basename(original_filepath))[0])

    def process_single(self, sample, context=False):
        # check if it's generated already
        if self.frame_key in sample:
            return sample

        # there are no videos in this sample
        if self.video_key not in sample or not sample[self.video_key]:
            return []

        # load videos
        loaded_video_keys = sample[self.video_key]
        sample, videos = load_data_with_context(sample, context,
                                                loaded_video_keys, load_video)
        video_to_frames = {}
        text = sample[self.text_key]
        offset = 0

        for chunk in text.split(SpecialTokens.eoc):
            video_count = chunk.count(SpecialTokens.video)
            # skip chunks with no video or no text
            if video_count == 0 or len(chunk) == 0:
                continue
            else:
                for video_key in loaded_video_keys[offset:offset +
                                                   video_count]:
                    video = videos[video_key]
                    # extract frames
                    if self.frame_sampling_method == 'all_keyframes':
                        frames = extract_key_frames(video)
                    elif self.frame_sampling_method == 'uniform':
                        frames = extract_video_frames_uniformly(
                            video, self.frame_num)
                    else:
                        raise ValueError(
                            f'Sampling method '
                            f'[{self.frame_sampling_method}] is not '
                            f'supported.')
                    frames = [frame.to_image() for frame in frames]

                    if self.frame_dir:
                        frame_dir = osp.join(
                            self.frame_dir,
                            osp.splitext(osp.basename(video_key))[0])
                    else:
                        # derive the frames directory from the video path
                        frame_dir = self._get_default_frame_dir(video_key)
                    os.makedirs(frame_dir, exist_ok=True)

                    video_to_frames[video_key] = []
                    for i, frame in enumerate(frames):
                        frame_path = osp.join(
                            frame_dir, self.frame_fname_template.format(i))
                        if not os.path.exists(frame_path):
                            frame.save(frame_path)

                        video_to_frames[video_key].append(frame_path)

                offset += video_count

        if not context:
            for vid_key in videos:
                close_video(videos[vid_key])

        sample[self.frame_key] = json.dumps(video_to_frames)

        return sample
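For orientation, here is a hedged usage sketch of the new op on a single sample, outside a full Data-Juicer pipeline ('demo.mp4' and './frames' are illustrative placeholders for a real local video and an output directory, not assets from this commit):

import json

from data_juicer.ops.mapper.video_extract_frames_mapper import \
    VideoExtractFramesMapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens

op = VideoExtractFramesMapper(frame_sampling_method='uniform',
                              frame_num=3,
                              frame_dir='./frames')
sample = {
    'text': f'{SpecialTokens.video} a short demo clip {SpecialTokens.eoc}',
    'videos': ['demo.mp4'],
}
result = op.process_single(sample)
# The frame index is stored as a JSON string mapping each video path to
# its saved frame files, e.g.
# {"demo.mp4": ["./frames/demo/frame_0.jpg", ...]}
print(json.loads(result[Fields.video_frames]))

With frame_sampling_method='all_keyframes', frame_num is ignored and the number of saved frames depends on the video's key frames, as the tests below illustrate.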
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
@@ -16,6 +16,7 @@ class Fields(object):
     context = DEFAULT_PREFIX + 'context__'
     suffix = DEFAULT_PREFIX + 'suffix__'

+    video_frames = DEFAULT_PREFIX + 'video_frames__'
     # video_frame_tags
     video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
     video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
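For reference, the new constant simply concatenates DEFAULT_PREFIX with 'video_frames__'; a small sketch of the resulting column name (assuming data-juicer's usual '__dj__' prefix, which this commit leaves untouched):

from data_juicer.utils.constant import Fields

# With the default '__dj__' prefix, the column written by the new op is:
print(Fields.video_frames)  # expected: __dj__video_frames__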
108 changes: 108 additions & 0 deletions tests/ops/mapper/test_video_extract_frames_mapper.py
@@ -0,0 +1,108 @@
import os
import os.path as osp
import copy
import unittest
import json
import tempfile
import shutil

from data_juicer.core.data import NestedDataset as Dataset
from data_juicer.ops.mapper.video_extract_frames_mapper import \
    VideoExtractFramesMapper
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class VideoExtractFramesMapperTest(DataJuicerTestCaseBase):

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')
    vid1_path = os.path.join(data_path, 'video1.mp4')
    vid2_path = os.path.join(data_path, 'video2.mp4')
    vid3_path = os.path.join(data_path, 'video3.mp4')
    tmp_dir = tempfile.TemporaryDirectory().name

    def tearDown(self):
        super().tearDown()
        shutil.rmtree(self.tmp_dir)

    def _run_video_extract_frames_mapper(self,
                                         op,
                                         source_list,
                                         target_list,
                                         num_proc=1):
        dataset = Dataset.from_list(source_list)
        dataset = dataset.map(op.process, batch_size=2, num_proc=num_proc)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def _get_frames_list(self, filepath, frame_dir, frame_num):
        frames_dir = osp.join(frame_dir,
                              osp.splitext(osp.basename(filepath))[0])
        frames_list = [
            osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)
        ]
        return frames_list

    def test_uniform_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text':
            f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
            'videos': [self.vid2_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_num = 3
        frame_dir = os.path.join(self.tmp_dir, 'test1')

        tgt_list = copy.deepcopy(ds_list)
        tgt_list[0].update({Fields.video_frames:
            json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, frame_num)})})
        tgt_list[1].update({Fields.video_frames:
            json.dumps({self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, frame_num)})})
        tgt_list[2].update({Fields.video_frames:
            json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, frame_num)})})

        op = VideoExtractFramesMapper(
            frame_sampling_method='uniform',
            frame_num=frame_num,
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)

    def test_all_keyframes_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text':
            f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}' +
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid2_path, self.vid3_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_dir = os.path.join(self.tmp_dir, 'test2')

        tgt_list = copy.deepcopy(ds_list)
        tgt_list[0].update({Fields.video_frames:
            json.dumps({self.vid1_path: self._get_frames_list(self.vid1_path, frame_dir, 3)})})
        tgt_list[1].update({Fields.video_frames: json.dumps({
            self.vid2_path: self._get_frames_list(self.vid2_path, frame_dir, 3),
            self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6)
        })})
        tgt_list[2].update({Fields.video_frames:
            json.dumps({self.vid3_path: self._get_frames_list(self.vid3_path, frame_dir, 6)})})

        op = VideoExtractFramesMapper(
            frame_sampling_method='all_keyframes',
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)


if __name__ == '__main__':
    unittest.main()
