-
Notifications
You must be signed in to change notification settings - Fork 186
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
292 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
import json | ||
import os | ||
import os.path as osp | ||
|
||
from pydantic import PositiveInt | ||
|
||
from data_juicer.utils.constant import Fields | ||
from data_juicer.utils.file_utils import create_directory_if_not_exists | ||
from data_juicer.utils.mm_utils import (SpecialTokens, close_video, | ||
extract_key_frames, | ||
extract_video_frames_uniformly, | ||
load_data_with_context, load_video) | ||
|
||
from ..base_op import OPERATORS, Mapper | ||
from ..op_fusion import LOADED_VIDEOS | ||
|
||
# Name under which the operator is registered in the registries below.
OP_NAME = 'video_extract_frames_mapper'


@OPERATORS.register_module(OP_NAME)
@LOADED_VIDEOS.register_module(OP_NAME)
class VideoExtractFramesMapper(Mapper):
    """Mapper to extract frames from video files according to specified
    methods.

    Extracted Frames Data Format:
        The data format for the extracted frames is a dictionary mapping
        video keys to lists of file paths where the extracted frames are
        saved. The dictionary is serialized with ``json.dumps`` before it
        is stored in the sample and follows the structure:
        {
            "video_key_1": [
                "/${frame_dir}/video_key_1_filename/frame_0.jpg",
                "/${frame_dir}/video_key_1_filename/frame_1.jpg",
                ...],
            "video_key_2": [
                "/${frame_dir}/video_key_2_filename/frame_0.jpg",
                "/${frame_dir}/video_key_2_filename/frame_1.jpg",
                ...],
            ...
        }
    """

    _batched_op = True

    def __init__(
        self,
        frame_sampling_method: str = 'all_keyframes',
        frame_num: PositiveInt = 3,
        frame_dir: str = None,
        frame_key=Fields.video_frames,
        *args,
        **kwargs,
    ):
        """
        Initialization method.

        :param frame_sampling_method: sampling method of extracting frame
            videos from the videos. Should be one of
            ["all_keyframes", "uniform"].
            The former one extracts all key frames (the number
            of which depends on the duration of the video) and the latter
            one extract specified number of frames uniformly from the video.
            Default: "all_keyframes".
        :param frame_num: the number of frames to be extracted uniformly from
            the video. Only works when frame_sampling_method is "uniform". If
            it's 1, only the middle frame will be extracted. If it's 2, only
            the first and the last frames will be extracted. If it's larger
            than 2, in addition to the first and the last frames, other frames
            will be extracted uniformly within the video duration.
        :param frame_dir: Output directory to save extracted frames.
            If None, a default directory based on the video file path is used.
        :param frame_key: The name of field to save generated frames info.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if frame_sampling_method is not supported.
        """
        super().__init__(*args, **kwargs)
        # Must run before any other local is bound: locals() is captured
        # as the record of the init parameters.
        self._init_parameters = self.remove_extra_parameters(locals())

        if frame_sampling_method not in ['all_keyframes', 'uniform']:
            raise ValueError(
                f'Frame sampling method '
                f'[{frame_sampling_method}] is not supported. '
                f'Can only be one of ["all_keyframes", "uniform"].')

        self.frame_dir = frame_dir
        self.frame_sampling_method = frame_sampling_method
        self.frame_num = frame_num
        self.frame_key = frame_key
        self.frame_fname_template = 'frame_{}.jpg'

    @staticmethod
    def _file_stem(filepath):
        """Return the file name of `filepath` without its extension."""
        return osp.splitext(osp.basename(filepath))[0]

    def _get_default_frame_dir(self, original_filepath):
        """Build the default frame directory for a video path.

        The directory is
        `<base>/<Fields.multimodal_data_output_dir>/<OP_NAME>/<video stem>`,
        where `<base>` is the directory of the video, cut at the first
        occurrence of the multimodal output directory token if the video
        already lives under one.

        :param original_filepath: path of the source video.
        :return: path of the directory for the frames of this video.
        """
        original_dir = osp.dirname(original_filepath)
        dir_token = f'/{Fields.multimodal_data_output_dir}/'
        if dir_token in original_dir:
            original_dir = original_dir.split(dir_token)[0]
        new_dir = osp.join(
            original_dir, f'{Fields.multimodal_data_output_dir}/{OP_NAME}')
        create_directory_if_not_exists(new_dir)
        return osp.join(new_dir, self._file_stem(original_filepath))

    def _resolve_frame_dir(self, video_key):
        """Return (and create) the directory the frames of a video go to."""
        if self.frame_dir:
            frame_dir = osp.join(self.frame_dir, self._file_stem(video_key))
        else:
            # video path as frames directory
            frame_dir = self._get_default_frame_dir(video_key)
        os.makedirs(frame_dir, exist_ok=True)
        return frame_dir

    def _extract_frames(self, video):
        """Extract frames with the configured method and convert to images.

        :raises ValueError: if the sampling method is not supported.
        """
        if self.frame_sampling_method == 'all_keyframes':
            frames = extract_key_frames(video)
        elif self.frame_sampling_method == 'uniform':
            frames = extract_video_frames_uniformly(video, self.frame_num)
        else:
            raise ValueError(f'Not support sampling method '
                             f'`{self.frame_sampling_method}`.')
        return [frame.to_image() for frame in frames]

    def _save_frames(self, frames, frame_dir):
        """Save frames into `frame_dir` and return their paths in order.

        A frame whose target file already exists is not written again.
        """
        frame_paths = []
        for i, frame in enumerate(frames):
            frame_path = osp.join(frame_dir,
                                  self.frame_fname_template.format(i))
            if not os.path.exists(frame_path):
                frame.save(frame_path)
            frame_paths.append(frame_path)
        return frame_paths

    def process_single(self, sample, context=False):
        """Extract the frames of every video referenced in the sample text.

        :param sample: the sample to process.
        :param context: passed through to `load_data_with_context`; when
            it is truthy the loaded videos are not closed here.
        :return: the sample, with a JSON string stored under `frame_key`
            that maps each video key to the list of its frame paths.
        """
        # check if it's generated already
        if self.frame_key in sample:
            return sample

        # There are no videos in this sample: still return a sample (not a
        # bare list) and record an empty mapping, so every processed sample
        # carries the frames field.
        if self.video_key not in sample or not sample[self.video_key]:
            sample[self.frame_key] = json.dumps({})
            return sample

        # load videos
        loaded_video_keys = sample[self.video_key]
        sample, videos = load_data_with_context(sample, context,
                                                loaded_video_keys, load_video)
        video_to_frames = {}
        try:
            text = sample[self.text_key]
            # index of the first video key belonging to the current chunk
            offset = 0
            for chunk in text.split(SpecialTokens.eoc):
                video_count = chunk.count(SpecialTokens.video)
                # no video (this also covers an empty chunk)
                if video_count == 0:
                    continue

                chunk_video_keys = loaded_video_keys[offset:offset +
                                                     video_count]
                for video_key in chunk_video_keys:
                    frames = self._extract_frames(videos[video_key])
                    frame_dir = self._resolve_frame_dir(video_key)
                    video_to_frames[video_key] = self._save_frames(
                        frames, frame_dir)

                offset += video_count
        finally:
            # Release the videos even if extraction or saving failed.
            if not context:
                for vid_key in videos:
                    close_video(videos[vid_key])

        sample[self.frame_key] = json.dumps(video_to_frames)

        return sample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import os | ||
import os.path as osp | ||
import copy | ||
import unittest | ||
import json | ||
import tempfile | ||
import shutil | ||
from data_juicer.core.data import NestedDataset as Dataset | ||
from data_juicer.ops.mapper.video_extract_frames_mapper import \ | ||
VideoExtractFramesMapper | ||
from data_juicer.utils.constant import Fields | ||
from data_juicer.utils.mm_utils import SpecialTokens | ||
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase | ||
|
||
|
||
class VideoExtractFramesMapperTest(DataJuicerTestCaseBase):
    """Tests for VideoExtractFramesMapper with both sampling methods."""

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                             '..', 'data')
    vid1_path = os.path.join(data_path, 'video1.mp4')
    vid2_path = os.path.join(data_path, 'video2.mp4')
    vid3_path = os.path.join(data_path, 'video3.mp4')

    def setUp(self):
        super().setUp()
        # A fresh directory per test. mkdtemp leaves the directory in place
        # until it is removed explicitly, unlike a discarded
        # TemporaryDirectory object which cleans up as soon as it is
        # garbage collected.
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        super().tearDown()
        # Tolerate a directory that is already gone so a failing test is
        # not masked by a cleanup error.
        shutil.rmtree(self.tmp_dir, ignore_errors=True)

    def _run_video_extract_frames_mapper(self,
                                         op,
                                         source_list,
                                         target_list,
                                         num_proc=1):
        """Run `op` over `source_list` and compare with `target_list`."""
        dataset = Dataset.from_list(source_list)
        dataset = dataset.map(op.process, batch_size=2, num_proc=num_proc)
        res_list = dataset.to_list()
        self.assertEqual(res_list, target_list)

    def _get_frames_list(self, filepath, frame_dir, frame_num):
        """Return the expected frame paths for the video at `filepath`."""
        frames_dir = osp.join(frame_dir,
                              osp.splitext(osp.basename(filepath))[0])
        return [
            osp.join(frames_dir, f'frame_{i}.jpg') for i in range(frame_num)
        ]

    def _expected_frames_field(self, frame_dir, video_to_frame_num):
        """Build the expected JSON value of the frames field.

        `video_to_frame_num` maps a video path to its expected frame count.
        """
        return json.dumps({
            video_path: self._get_frames_list(video_path, frame_dir, num)
            for video_path, num in video_to_frame_num.items()
        })

    def test_uniform_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text':
            f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}',
            'videos': [self.vid2_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_num = 3
        frame_dir = os.path.join(self.tmp_dir, 'test1')

        tgt_list = copy.deepcopy(ds_list)
        expected_counts = [
            {self.vid1_path: frame_num},
            {self.vid2_path: frame_num},
            {self.vid3_path: frame_num},
        ]
        for tgt, counts in zip(tgt_list, expected_counts):
            tgt[Fields.video_frames] = self._expected_frames_field(
                frame_dir, counts)

        op = VideoExtractFramesMapper(
            frame_sampling_method='uniform',
            frame_num=frame_num,
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)

    def test_all_keyframes_sampling(self):
        ds_list = [{
            'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。',
            'videos': [self.vid1_path]
        }, {
            'text': (
                f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}'
                f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}'),
            'videos': [self.vid2_path, self.vid3_path]
        }, {
            'text':
            f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}',
            'videos': [self.vid3_path]
        }]
        frame_dir = os.path.join(self.tmp_dir, 'test2')

        tgt_list = copy.deepcopy(ds_list)
        # Expected number of key frames in each test video.
        expected_counts = [
            {self.vid1_path: 3},
            {self.vid2_path: 3, self.vid3_path: 6},
            {self.vid3_path: 6},
        ]
        for tgt, counts in zip(tgt_list, expected_counts):
            tgt[Fields.video_frames] = self._expected_frames_field(
                frame_dir, counts)

        op = VideoExtractFramesMapper(
            frame_sampling_method='all_keyframes',
            frame_dir=frame_dir)
        self._run_video_extract_frames_mapper(op, ds_list, tgt_list)


if __name__ == '__main__':
    unittest.main()