diff --git a/configs/config_all.yaml b/configs/config_all.yaml
index 1b19bb309..708db5c22 100644
--- a/configs/config_all.yaml
+++ b/configs/config_all.yaml
@@ -23,6 +23,12 @@ trace_num: 10 # number of samples
op_fusion: false # whether to fuse operators that share the same intermediate variables automatically. Op fusion might reduce the memory requirements slightly but speed up the whole process.
cache_compress: null # The compression method of the cache file, which can be specified in ['gzip', 'zstd', 'lz4']. If this parameter is None, the cache file will not be compressed. We recommend you turn on this argument when your input dataset is larger than tens of GB and your disk space is not enough.
+# for multimodal data processing
+image_key: 'images' # Key name of field to store the list of sample image paths.
+image_special_token: '<__dj__image>' # The special token that represents an image in the text. By default, it's "<__dj__image>". You can specify your own special token according to your input dataset.
+
+eoc_special_token: '<|__dj__eoc|>' # The special token that represents the end of a chunk in the text. By default, it's "<|__dj__eoc|>". You can specify your own special token according to your input dataset.
+
# for distributed processing
executor_type: default # Type of executor, support "default" or "ray" for now.
ray_address: auto # The address of the Ray cluster.
@@ -110,6 +116,10 @@ process:
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
+  - image_aspect_ratio_filter: # filter samples according to the aspect ratios of the images in them (the ratio of width to height, r = w / h)
+ min_ratio: 0.333 # the min aspect ratio of filter range
+ max_ratio: 3.0 # the max aspect ratio of filter range
+ any_or_all: any # keep this sample when any/all images meet the filter condition
- language_id_score_filter: # filter text in specific language with language scores larger than a specific max value
lang: en # keep text in what language
min_score: 0.8 # the min language scores to filter text
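For orientation, the snippet below sketches what a single multimodal sample looks like under the default keys and tokens above. It's a minimal sketch with a hypothetical image path; the exact layout is produced by the converters under `tools/multimodal/`, described later in this diff.

```python
# A minimal sketch of one multimodal sample under the default settings above:
# 'images' stores the image path list, while the special tokens mark where an
# image appears in the text and where a chunk ends.
sample = {
    'images': ['coco/train2017/000000033471.jpg'],  # hypothetical path
    'text': '<__dj__image> What are the colors of the bus in the image? '
            '<|__dj__eoc|>',
}
```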
diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py
index 3a5ff0ed0..f59d8cdc9 100644
--- a/data_juicer/config/config.py
+++ b/data_juicer/config/config.py
@@ -11,6 +11,7 @@
from data_juicer.ops.base_op import OPERATORS
from data_juicer.utils.logger_utils import setup_logger
+from data_juicer.utils.mm_utils import SpecialTokens
def init_configs(args=None):
@@ -102,6 +103,25 @@ def init_configs(args=None):
'requiring multiple keys, users can specify the op multiple '
'times. We will only use the first key of `text_keys` when you '
'set multiple keys.')
+ parser.add_argument(
+ '--image_key',
+ type=str,
+ default='images',
+ help='Key name of field to store the list of sample image paths.')
+ parser.add_argument(
+ '--image_special_token',
+ type=str,
+ default=SpecialTokens.image,
+        help='The special token that represents an image in the text. By '
+        'default, it\'s "<__dj__image>". You can specify your own special'
+        ' token according to your input dataset.')
+ parser.add_argument(
+ '--eoc_special_token',
+ type=str,
+ default=SpecialTokens.eoc,
+        help='The special token that represents the end of a chunk in the '
+        'text. By default, it\'s "<|__dj__eoc|>". You can specify your '
+        'own special token according to your input dataset.')
parser.add_argument(
'--suffixes',
type=Union[str, List[str], Tuple[str]],
@@ -289,6 +309,19 @@ def init_setup_from_cfg(cfg):
filename=logfile_name,
redirect=cfg.executor_type == 'default')
+ # check and get dataset dir
+ if os.path.exists(cfg.dataset_path):
+ if os.path.isdir(cfg.dataset_path):
+ cfg.dataset_dir = os.path.abspath(cfg.dataset_path)
+ else:
+ cfg.dataset_dir = os.path.abspath(
+ os.path.dirname(cfg.dataset_path))
+ else:
+ logger.error(f'Input dataset_path [{cfg.dataset_path}] is invalid. '
+ f'Please check and retry.')
+ raise ValueError(f'Input dataset_path [{cfg.dataset_path}] is '
+ f'invalid. Please check and retry.')
+
# whether or not to use cache management
# disabling the cache or using checkpoint explicitly will turn off the
# cache management.
@@ -334,6 +367,10 @@ def init_setup_from_cfg(cfg):
cfg.add_suffix = True
break
+ # update special tokens
+ SpecialTokens.image = cfg.image_special_token
+ SpecialTokens.eoc = cfg.eoc_special_token
+
# Apply text_key modification during initializing configs
# users can freely specify text_key for different ops using `text_key`
# otherwise, set arg text_key of each op to text_keys
@@ -345,9 +382,13 @@ def init_setup_from_cfg(cfg):
for op_name in op:
args = op[op_name]
if args is None:
- args = {'text_key': text_key}
+ args = {
+ 'text_key': text_key,
+ 'image_key': cfg.image_key,
+ }
elif args['text_key'] is None:
args['text_key'] = text_key
+ args['image_key'] = cfg.image_key
op[op_name] = args
return cfg
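To illustrate the hunk above: `init_setup_from_cfg` patches the `SpecialTokens` class attributes in place, so every op that reads `SpecialTokens` afterwards sees the user-specified tokens. A minimal sketch with hypothetical custom tokens:

```python
from data_juicer.utils.mm_utils import SpecialTokens

# What the update above amounts to: the class attributes are overwritten in
# place, so every later reader of SpecialTokens sees the configured values.
SpecialTokens.image = '<my_img>'   # hypothetical cfg.image_special_token
SpecialTokens.eoc = '<my_eoc>'     # hypothetical cfg.eoc_special_token
assert SpecialTokens.image == '<my_img>'
```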
diff --git a/data_juicer/core/analyser.py b/data_juicer/core/analyser.py
index a61d64b0b..c1b0b93af 100644
--- a/data_juicer/core/analyser.py
+++ b/data_juicer/core/analyser.py
@@ -73,7 +73,7 @@ def run(self, load_data_np=None):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
- dataset = self.formatter.load_dataset(load_data_np)
+ dataset = self.formatter.load_dataset(load_data_np, self.cfg)
# extract processes
logger.info('Preparing process operators...')
diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py
index eb32e306d..8fb0c8b2b 100644
--- a/data_juicer/core/executor.py
+++ b/data_juicer/core/executor.py
@@ -94,7 +94,7 @@ def run(self, load_data_np=None):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
- dataset = self.formatter.load_dataset(load_data_np)
+ dataset = self.formatter.load_dataset(load_data_np, self.cfg)
# 2. extract processes
logger.info('Preparing process operators...')
diff --git a/data_juicer/core/exporter.py b/data_juicer/core/exporter.py
index 9450c9482..1ca6f19bf 100644
--- a/data_juicer/core/exporter.py
+++ b/data_juicer/core/exporter.py
@@ -176,7 +176,7 @@ def export(self, dataset):
@staticmethod
def to_jsonl(dataset, export_path, num_proc=1, **kwargs):
"""
- Export method for json/jsonl target files.
+ Export method for jsonl target files.
:param dataset: the dataset to export.
:param export_path: the path to store the exported dataset.
@@ -186,6 +186,19 @@ def to_jsonl(dataset, export_path, num_proc=1, **kwargs):
"""
dataset.to_json(export_path, force_ascii=False, num_proc=num_proc)
+ @staticmethod
+ def to_json(dataset, export_path, num_proc=1, **kwargs):
+ """
+ Export method for json target files.
+
+ :param dataset: the dataset to export.
+ :param export_path: the path to store the exported dataset.
+ :param num_proc: the number of processes used to export the dataset.
+ :param kwargs: extra arguments.
+ :return:
+ """
+        dataset.to_json(export_path,
+                        force_ascii=False,
+                        num_proc=num_proc,
+                        lines=False)
+
@staticmethod
def to_parquet(dataset, export_path, **kwargs):
"""
@@ -208,6 +221,6 @@ def _router():
"""
return {
'jsonl': Exporter.to_jsonl,
- 'json': Exporter.to_jsonl,
+ 'json': Exporter.to_json,
'parquet': Exporter.to_parquet,
}
diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py
index 0a8629bfc..a297463b7 100644
--- a/data_juicer/format/formatter.py
+++ b/data_juicer/format/formatter.py
@@ -51,7 +51,9 @@ def __init__(
self.data_files = find_files_with_suffix(dataset_path, suffixes)
self.add_suffix = add_suffix
- def load_dataset(self, num_proc: int = 1) -> Dataset:
+ def load_dataset(self,
+ num_proc: int = 1,
+ global_cfg=None) -> Dataset:
"""
Load a dataset from dataset file or dataset directory, and unify its
format.
@@ -76,7 +78,8 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
concatenate_datasets([ds for _, ds in datasets.items()]))
ds = unify_format(datasets,
text_keys=self.text_keys,
- num_proc=num_proc)
+ num_proc=num_proc,
+ global_cfg=global_cfg)
return ds
@@ -100,7 +103,9 @@ def __init__(self,
self.text_keys = text_keys
self.kwargs = kwargs
- def load_dataset(self, num_proc: int = 1) -> Dataset:
+ def load_dataset(self,
+ num_proc: int = 1,
+ global_cfg=None) -> Dataset:
"""
Load a dataset from HuggingFace, and unify its format.
@@ -112,7 +117,10 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
split='train',
num_proc=num_proc,
**self.kwargs)
- ds = unify_format(ds, text_keys=self.text_keys, num_proc=num_proc)
+ ds = unify_format(ds,
+ text_keys=self.text_keys,
+ num_proc=num_proc,
+ global_cfg=global_cfg)
return ds
@@ -137,6 +145,7 @@ def unify_format(
dataset: Dataset,
text_keys: Union[List[str], str] = 'text',
num_proc: int = 1,
+ global_cfg=None,
) -> Dataset:
"""
Get an unified internal format, conduct the following modifications.
@@ -201,12 +210,40 @@ def non_empty_text(sample, target_keys):
fn_kwargs={'target_keys': text_keys})
logger.info(f'{len(dataset)} samples left after filtering empty text.')
- # 3. add Fields.stats field
- # TODO:
- # this is a temp solution,
- # it will occur errors when only call mapper ops
- # dataset = dataset.add_column( \
- # name=Fields.stats, column=[{}] * dataset.num_rows)
+ # 3. convert relative paths to absolute paths
+ if global_cfg:
+        logger.info('Converting relative paths in the dataset to their '
+                    'absolute versions (based on the directory of the '
+                    'input dataset file).')
+ ds_dir = global_cfg.dataset_dir
+ image_key = global_cfg.image_key
+
+ # function to convert relative paths to absolute paths
+ def rel2abs(sample, path_keys, dataset_dir):
+ for path_key in path_keys:
+ if path_key not in sample:
+ continue
+ paths = sample[path_key]
+ if not paths:
+ continue
+                # keep absolute paths as-is and only join relative ones,
+                # so that mixed path lists are not silently truncated
+                new_paths = [
+                    path if os.path.isabs(path) else os.path.join(
+                        dataset_dir, path) for path in paths
+                ]
+ sample[path_key] = new_paths
+ return sample
+
+ dataset = dataset.map(rel2abs,
+ num_proc=num_proc,
+ fn_kwargs={
+ 'path_keys': [
+ image_key,
+ ],
+ 'dataset_dir': ds_dir
+ })
+ else:
+        logger.warning('No global config passed into the unify_format '
+                       'function. Relative paths in the dataset might not be '
+                       'converted to their absolute versions, and data of '
+                       'other modalities might not be found by Data-Juicer.')
return dataset
@@ -262,6 +299,8 @@ def load_formatter(dataset_path,
# no data
else:
- raise ValueError('Can not found local data or huggingface '
- 'dataset-hub for your given path: '
- f'{dataset_path} and suffixes: {suffixes}')
+        raise ValueError(f'Unable to load the dataset from [{dataset_path}]. '
+                         f'It might be because Data-Juicer doesn\'t support '
+                         f'the format of this dataset, or the path of this '
+                         f'dataset is incorrect. Please check if it\'s a '
+                         f'valid dataset path and retry.')
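A self-contained sketch of the (fixed) path-conversion behavior in `unify_format`: absolute paths pass through unchanged, while relative ones are joined with the dataset directory (all paths here are hypothetical).

```python
import os

def rel2abs(sample, path_keys, dataset_dir):
    # mirrors the fixed helper inside unify_format above
    for path_key in path_keys:
        paths = sample.get(path_key)
        if not paths:
            continue
        sample[path_key] = [
            path if os.path.isabs(path) else os.path.join(dataset_dir, path)
            for path in paths
        ]
    return sample

s = {'images': ['img/a.jpg', '/mnt/data/b.jpg']}
print(rel2abs(s, ['images'], '/data/my_dataset')['images'])
# ['/data/my_dataset/img/a.jpg', '/mnt/data/b.jpg']
```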
diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py
index 17bd057ba..f55907f90 100644
--- a/data_juicer/format/mixture_formatter.py
+++ b/data_juicer/format/mixture_formatter.py
@@ -81,16 +81,17 @@ def _random_sample(self, dataset, weight=1.0, seed=None):
return dataset
return dataset.shuffle(seed=seed).select(range(num_samples))
- def load_dataset(self, num_proc: int = 1) -> Dataset:
+ def load_dataset(self, num_proc: int = 1, global_cfg=None) -> Dataset:
"""
Load a mixed dataset.
:param num_proc: number of processes when loading the dataset
+    :param global_cfg: the global cfg used in subsequent processes.
:return: mixed dataset
"""
dataset_list = []
for weight, formatter in zip(self.weights, self.formatters):
- dataset = formatter.load_dataset(num_proc)
+ dataset = formatter.load_dataset(num_proc, global_cfg)
sampled = self._random_sample(dataset, weight)
logger.info(f'sampled {len(sampled)} from '
f'{len(dataset)} with weight {weight}')
diff --git a/data_juicer/format/text_formatter.py b/data_juicer/format/text_formatter.py
index fbca468cb..fdad34fac 100644
--- a/data_juicer/format/text_formatter.py
+++ b/data_juicer/format/text_formatter.py
@@ -96,11 +96,14 @@ def __init__(self,
self.dataset_path = dataset_path
self.add_suffix = add_suffix
- def load_dataset(self, num_proc: int = 1) -> Dataset:
+ def load_dataset(self,
+ num_proc: int = 1,
+ global_cfg=None) -> Dataset:
"""
Load a dataset from local text-type files.
:param num_proc: number of processes when loading the dataset
+    :param global_cfg: the global cfg used in subsequent processes.
:return: unified_format_dataset.
"""
# extract text to cache directory
@@ -154,4 +157,5 @@ def load_dataset(self, num_proc: int = 1) -> Dataset:
datasets = concatenate_datasets([ds for _, ds in datasets.items()])
return unify_format(datasets,
text_keys=self.text_keys,
- num_proc=num_proc)
+ num_proc=num_proc,
+ global_cfg=global_cfg)
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index 764feda64..be45a9673 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -2,22 +2,49 @@
OPERATORS = Registry('Operators')
-
-class Mapper:
-
- def __init__(self, text_key: str = None):
+class OP:
+ def __init__(self,
+ text_key: str = None,
+ image_key: str = None,
+ ):
"""
- Base class that conducts text editing.
+ Base class of operators.
:param text_key: the key name of field that stores sample texts
to be processed.
+    :param image_key: the key name of field that stores sample image list
+        to be processed.
"""
+ # init data keys
if text_key is None:
text_key = 'text'
self.text_key = text_key
+ if image_key is None:
+ image_key = 'images'
+ self.image_key = image_key
+
from data_juicer.core.data import wrap_func_with_nested_access
self.process = wrap_func_with_nested_access(self.process)
+ def process(self, *args, **kwargs):
+ raise NotImplementedError
+
+class Mapper(OP):
+
+ def __init__(self,
+ text_key: str = None,
+ image_key: str = None,
+ ):
+ """
+ Base class that conducts data editing.
+
+ :param text_key: the key name of field that stores sample texts
+ to be processed.
+    :param image_key: the key name of field that stores sample image list
+        to be processed.
+ """
+ super(Mapper, self).__init__(text_key, image_key)
+
# In default, it's a normal OP instead of batched OP
self._batched_op = False
@@ -34,20 +61,23 @@ def is_batched_op(self):
return self._batched_op
-class Filter:
+class Filter(OP):
- def __init__(self, text_key: str = None):
+ def __init__(self,
+ text_key: str = None,
+ image_key: str = None,
+ ):
"""
Base class that removes specific info.
:param text_key: the key name of field that stores sample texts
to be processed
+    :param image_key: the key name of field that stores sample image list
+        to be processed.
"""
- if text_key is None:
- text_key = 'text'
- self.text_key = text_key
+ super(Filter, self).__init__(text_key, image_key)
+
from data_juicer.core.data import wrap_func_with_nested_access
- self.process = wrap_func_with_nested_access(self.process)
self.compute_stats = wrap_func_with_nested_access(self.compute_stats)
def compute_stats(self, sample, context=False):
@@ -72,20 +102,23 @@ def process(self, sample):
raise NotImplementedError
-class Deduplicator:
+class Deduplicator(OP):
- def __init__(self, text_key: str = None):
+ def __init__(self,
+ text_key: str = None,
+ image_key: str = None,
+ ):
"""
Base class that conducts deduplication.
:param text_key: the key name of field that stores sample texts
to be processed
+    :param image_key: the key name of field that stores sample image list
+        to be processed.
"""
- if text_key is None:
- text_key = 'text'
- self.text_key = text_key
+ super(Deduplicator, self).__init__(text_key, image_key)
+
from data_juicer.core.data import wrap_func_with_nested_access
- self.process = wrap_func_with_nested_access(self.process)
self.compute_hash = wrap_func_with_nested_access(self.compute_hash)
def compute_hash(self, sample):
@@ -109,20 +142,21 @@ def process(self, dataset, show_num=0):
raise NotImplementedError
-class Selector:
+class Selector(OP):
- def __init__(self, text_key: str = None):
+ def __init__(self,
+ text_key: str = None,
+ image_key: str = None,
+ ):
"""
Base class that conducts selection in dataset-level.
:param text_key: the key name of field that stores sample texts
to be processed
+    :param image_key: the key name of field that stores sample image list
+        to be processed.
"""
- if text_key is None:
- text_key = 'text'
- self.text_key = text_key
- from data_juicer.core.data import wrap_func_with_nested_access
- self.process = wrap_func_with_nested_access(self.process)
+ super(Selector, self).__init__(text_key, image_key)
def process(self, dataset):
"""
diff --git a/data_juicer/ops/filter/__init__.py b/data_juicer/ops/filter/__init__.py
index b6e48ac5a..c9332eea0 100644
--- a/data_juicer/ops/filter/__init__.py
+++ b/data_juicer/ops/filter/__init__.py
@@ -1,7 +1,8 @@
from . import (alphanumeric_filter, average_line_length_filter,
character_repetition_filter, flagged_words_filter,
- language_id_score_filter, maximum_line_length_filter,
- perplexity_filter, special_characters_filter,
- specified_field_filter, specified_numeric_field_filter,
- stopwords_filter, suffix_filter, text_length_filter,
- token_num_filter, word_num_filter, word_repetition_filter)
+ image_aspect_ratio_filter, language_id_score_filter,
+ maximum_line_length_filter, perplexity_filter,
+ special_characters_filter, specified_field_filter,
+ specified_numeric_field_filter, stopwords_filter, suffix_filter,
+ text_length_filter, token_num_filter, word_num_filter,
+ word_repetition_filter)
diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py
new file mode 100644
index 000000000..0af4ec214
--- /dev/null
+++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py
@@ -0,0 +1,97 @@
+import numpy as np
+from jsonargparse.typing import PositiveFloat
+
+from data_juicer.utils.constant import Fields, StatsKeys
+from data_juicer.utils.mm_utils import load_image
+
+from ..base_op import OPERATORS, Filter
+from ..op_fusion import LOADED_IMAGES
+
+
+@OPERATORS.register_module('image_aspect_ratio_filter')
+@LOADED_IMAGES.register_module('image_aspect_ratio_filter')
+class ImageAspectRatioFilter(Filter):
+ """Filter to keep samples with image aspect ratio within a specific range.
+ AspectRatio = W / H.
+ """
+
+ def __init__(self,
+ min_ratio: PositiveFloat = 0.333,
+ max_ratio: PositiveFloat = 3.0,
+ any_or_all: str = 'any',
+ *args,
+ **kwargs):
+ """
+ Initialization method.
+
+ :param min_ratio: The min aspect ratio to keep samples.
+ :param max_ratio: The max aspect ratio to keep samples.
+ :param any_or_all: keep this sample with 'any' or 'all' strategy of
+ all images. 'any': keep this sample if any images meet the
+ condition. 'all': keep this sample only if all images meet the
+ condition.
+ :param args: extra args
+ :param kwargs: extra args
+ """
+ super().__init__(*args, **kwargs)
+ self.min_ratio = min_ratio
+ self.max_ratio = max_ratio
+ if any_or_all not in ['any', 'all']:
+ raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
+ f'Can only be one of ["any", "all"].')
+ self.any = (any_or_all == 'any')
+
+ def compute_stats(self, sample, context=False):
+ # check if it's computed already
+ if StatsKeys.aspect_ratios in sample[Fields.stats]:
+ return sample
+
+ # there is no image in this sample
+ if self.image_key not in sample or not sample[self.image_key]:
+ sample[Fields.stats][StatsKeys.aspect_ratios] = np.array(
+ [], dtype=np.float64)
+ return sample
+
+ # load images
+ loaded_image_keys = sample[self.image_key]
+ images = {}
+ for loaded_image_key in loaded_image_keys:
+ if context and loaded_image_key in sample[Fields.context]:
+ # load from context
+ images[loaded_image_key] = sample[
+ Fields.context][loaded_image_key]
+ else:
+ if loaded_image_key not in images:
+                    # avoid loading the same image repeatedly
+ image = load_image(loaded_image_key)
+ images[loaded_image_key] = image
+ if context:
+ # store the image data into context
+ sample[Fields.context][loaded_image_key] = image
+
+ # compute aspect ratios for each image with W/H
+ aspect_ratios = {
+ key: (images[key].width / images[key].height)
+ for key in images
+ }
+ sample[Fields.stats][StatsKeys.aspect_ratios] = [
+ aspect_ratios[key] for key in loaded_image_keys
+ ]
+ return sample
+
+ def process(self, sample):
+ aspect_ratios = sample[Fields.stats][StatsKeys.aspect_ratios]
+ keep_bools = np.array([
+ self.min_ratio <= aspect_ratio <= self.max_ratio
+ for aspect_ratio in aspect_ratios])
+        # keep the sample if it contains no images (no stats to check)
+        if len(keep_bools) == 0:
+            return True
+
+ # different strategies
+ if self.any:
+ return keep_bools.any()
+ else:
+ return keep_bools.all()
+
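A usage sketch for the new op, using one of the test images added later in this diff; as in the test helper below, the `Fields.stats` field must exist on the sample before `compute_stats` is called.

```python
from data_juicer.ops.filter.image_aspect_ratio_filter import \
    ImageAspectRatioFilter
from data_juicer.utils.constant import Fields

op = ImageAspectRatioFilter(min_ratio=0.8, max_ratio=1.2, any_or_all='any')
sample = {'images': ['tests/ops/data/img1.png'], Fields.stats: {}}
sample = op.compute_stats(sample)
print(op.process(sample))  # True if any image's W/H lies in [0.8, 1.2]
```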
diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py
index 099fc28d0..518bd01bc 100644
--- a/data_juicer/ops/op_fusion.py
+++ b/data_juicer/ops/op_fusion.py
@@ -8,9 +8,16 @@
from .base_op import Filter
# Type of intermediate vars
+# text
INTER_LINES = Registry(InterVars.lines)
INTER_WORDS = Registry(InterVars.words)
+# images
+LOADED_IMAGES = Registry(InterVars.loaded_images)
+
+# all
+ALL_INTER_VARS = [INTER_LINES, INTER_WORDS, LOADED_IMAGES]
+
def fuse_operators(process_list, ops):
"""
@@ -62,7 +69,7 @@ def fuse_filter_group(original_filter_group):
"""
fused_group_def = []
fused_group = []
- all_intermediate_vars = [INTER_LINES, INTER_WORDS]
+ all_intermediate_vars = ALL_INTER_VARS
all_fused_filters = {
inter_vars: []
for inter_vars in all_intermediate_vars
diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py
index 24a918999..b2f3a362f 100644
--- a/data_juicer/utils/constant.py
+++ b/data_juicer/utils/constant.py
@@ -9,6 +9,7 @@ class Fields(object):
class StatsKeys(object):
+ # text
alpha_token_ratio = 'alpha_token_ratio'
alnum_ratio = 'alnum_ratio'
avg_line_length = 'avg_line_length'
@@ -25,6 +26,9 @@ class StatsKeys(object):
num_words = 'num_words'
word_rep_ratio = 'word_rep_ratio'
+ # image
+ aspect_ratios = 'aspect_ratios'
+
class HashKeys(object):
hash = DEFAULT_PREFIX + 'hash'
@@ -33,6 +37,10 @@ class HashKeys(object):
class InterVars(object):
+ # text
lines = DEFAULT_PREFIX + 'lines'
words = DEFAULT_PREFIX + 'words'
refined_words = DEFAULT_PREFIX + 'refined_words'
+
+ # image
+ loaded_images = DEFAULT_PREFIX + 'loaded_images'
diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py
new file mode 100644
index 000000000..b67484062
--- /dev/null
+++ b/data_juicer/utils/mm_utils.py
@@ -0,0 +1,21 @@
+from datasets import Image
+
+from data_juicer.utils.constant import DEFAULT_PREFIX
+
+
+# A class to keep special tokens for multimodal information in the texts.
+# The tokens in this class can be updated by the corresponding arguments
+# in the config.
+class SpecialTokens(object):
+    # modality
+    image = f'<{DEFAULT_PREFIX}image>'
+
+    # others
+    eoc = f'<|{DEFAULT_PREFIX}eoc|>'
+
+
+def load_images(paths):
+    return [load_image(path) for path in paths]
+
+
+def load_image(path):
+    img_feature = Image()
+    img = img_feature.decode_example(img_feature.encode_example(path))
+    return img
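A small sketch of the new utilities: `load_image` round-trips a path through the `datasets.Image` feature and returns a decoded `PIL.Image`, and the special tokens can be used to compose text chunks (the image path comes from the test data added below).

```python
from data_juicer.utils.mm_utils import SpecialTokens, load_image

# compose a text chunk with the default special tokens
text = f'{SpecialTokens.image} a photo of a bus {SpecialTokens.eoc}'
print(text)

img = load_image('tests/ops/data/img1.png')
print(img.width / img.height)  # the aspect ratio used by the new filter
```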
diff --git a/demos/overview_scan/app.py b/demos/overview_scan/app.py
index d1b109ac8..378b8f502 100644
--- a/demos/overview_scan/app.py
+++ b/demos/overview_scan/app.py
@@ -89,7 +89,7 @@
|-----------------------------------|:------:|-------------------------------------------------|
| Formatter | 7 | Discovers, loads, and canonicalizes source data |
| Mapper | 21 | Edits and transforms samples |
-| Filter | 16 | Filters out low-quality samples |
+| Filter | 17 | Filters out low-quality samples |
| Deduplicator | 3 | Detects and removes duplicate samples |
| Selector | 2 | Selects top samples based on ranking |
'''
@@ -141,6 +141,7 @@
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
+| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold |
diff --git a/docs/Operators.md b/docs/Operators.md
index 107dc57d1..78abeb495 100644
--- a/docs/Operators.md
+++ b/docs/Operators.md
@@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 21 | Edits and transforms samples |
-| [ Filter ]( #filter ) | 16 | Filters out low-quality samples |
+| [ Filter ]( #filter ) | 17 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 3 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 2 | Selects top samples based on ranking |
@@ -76,6 +76,7 @@ All the specific operators are listed below, each featured with several capabili
| average_line_length_filter | Code | en, zh | Keeps samples with average line length within the specified range |
| character_repetition_filter | General | en, zh | Keeps samples with char-level n-gram repetition ratio within the specified range |
| flagged_words_filter | General | en, zh | Keeps samples with flagged-word ratio below the specified threshold |
+| image_aspect_ratio_filter | Image | - | Keeps samples containing images with aspect ratios within the specified range |
| language_id_score_filter | General | en, zh | Keeps samples of the specified language, judged by a predicted confidence score |
| maximum_line_length_filter | Code | en, zh | Keeps samples with maximum line length within the specified range |
| perplexity_filter | General | en, zh | Keeps samples with perplexity score below the specified threshold |
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md
index bd9f5d95a..cf3421d94 100644
--- a/docs/Operators_ZH.md
+++ b/docs/Operators_ZH.md
@@ -10,7 +10,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 21 | 对数据样本进行编辑和转换 |
-| [ Filter ]( #filter ) | 16 | 过滤低质量样本 |
+| [ Filter ]( #filter ) | 17 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 3 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 2 | 基于排序选取高质量样本 |
@@ -66,24 +66,25 @@ Data-Juicer 中的算子分为以下 5 种类型。
## Filter
-| 算子 | 场景 | 语言 | 描述 |
-|--------------------------------|----------|---------|------------------------------------|
-| alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 |
-| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
-| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
-| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
-| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
-| maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 |
-| perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 |
-| special_characters_filter | General | en, zh | 保留 special-char 比率的在指定范围内的样本 |
-| specified_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定目标中 |
-| specified_numeric_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) |
-| stopwords_filter | General | en, zh | 保留停用词比率高于指定阈值的样本 |
-| suffix_filter | General | en, zh | 保留包含特定后缀的样本 |
-| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 |
-| token_num_filter | General | en, zh | 保留token数在指定范围内的样本 |
-| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
-| word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 |
+| 算子 | 场景 | 语言 | 描述 |
+|--------------------------------|---------|--------|------------------------------------|
+| alphanumeric_filter | General | en, zh | 保留字母数字比例在指定范围内的样本 |
+| average_line_length_filter | Code | en, zh | 保留平均行长度在指定范围内的样本 |
+| character_repetition_filter | General | en, zh | 保留 char-level n-gram 重复比率在指定范围内的样本 |
+| flagged_words_filter | General | en, zh | 保留使标记字比率保持在指定阈值以下的样本 |
+| image_aspect_ratio_filter | Image | - | 保留样本中包含的图片的宽高比在指定范围内的样本 |
+| language_id_score_filter | General | en, zh | 保留特定语言的样本,通过预测的置信度得分来判断 |
+| maximum_line_length_filter | Code | en, zh | 保留最大行长度在指定范围内的样本 |
+| perplexity_filter | General | en, zh | 保留困惑度低于指定阈值的样本 |
+| special_characters_filter | General | en, zh | 保留 special-char 比率的在指定范围内的样本 |
+| specified_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定目标中 |
+| specified_numeric_field_filter | General | en, zh | 根据字段过滤样本,要求字段的值处于指定范围(针对数字类型) |
+| stopwords_filter | General | en, zh | 保留停用词比率高于指定阈值的样本 |
+| suffix_filter | General | en, zh | 保留包含特定后缀的样本 |
+| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 |
+| token_num_filter | General | en, zh | 保留token数在指定范围内的样本 |
+| word_num_filter | General | en, zh | 保留字数在指定范围内的样本 |
+| word_repetition_filter | General | en, zh | 保留 word-level n-gram 重复比率在指定范围内的样本 |
## Deduplicator
diff --git a/tests/config/demo_4_test.yaml b/tests/config/demo_4_test.yaml
index 39d11fd8f..0fe834613 100644
--- a/tests/config/demo_4_test.yaml
+++ b/tests/config/demo_4_test.yaml
@@ -2,7 +2,7 @@
# global parameters
project_name: 'test_demo'
-dataset_path: './demo/demo-dataset.jsonl' # path to your dataset directory or file
+dataset_path: './demos/data/demo-dataset.jsonl' # path to your dataset directory or file
np: 4 # number of subprocess to process your dataset
export_path: './outputs/demo/demo-processed.parquet'
diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py
index 231121345..913081f52 100644
--- a/tests/config/test_config_funcs.py
+++ b/tests/config/test_config_funcs.py
@@ -37,14 +37,16 @@ def test_yaml_cfg_file(self):
self.assertDictEqual(
cfg.process[0],
{'whitespace_normalization_mapper': {
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}}, 'nested dict load fail, for nonparametric op')
self.assertDictEqual(
cfg.process[1], {
'language_id_score_filter': {
'lang': 'zh',
'min_score': 0.8,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
}, 'nested dict load fail, un-expected internal value')
@@ -74,7 +76,8 @@ def test_mixture_cfg(self):
'language_id_score_filter': {
'lang': 'zh',
'min_score': 0.8,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
})
self.assertDictEqual(
@@ -82,7 +85,8 @@ def test_mixture_cfg(self):
'language_id_score_filter': {
'lang': 'en',
'min_score': 0.8,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
})
self.assertDictEqual(
@@ -90,7 +94,8 @@ def test_mixture_cfg(self):
'language_id_score_filter': {
'lang': 'fr',
'min_score': 0.8,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
})
self.assertDictEqual(
@@ -98,7 +103,8 @@ def test_mixture_cfg(self):
'language_id_score_filter': {
'lang': 'zh',
'min_score': 0.6,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
})
self.assertDictEqual(
@@ -106,7 +112,8 @@ def test_mixture_cfg(self):
'language_id_score_filter': {
'lang': 'en',
'min_score': 0.5,
- 'text_key': 'text'
+ 'text_key': 'text',
+ 'image_key': 'images',
}
})
diff --git a/tests/ops/data/img1.png b/tests/ops/data/img1.png
new file mode 100644
index 000000000..8d9e70b8e
Binary files /dev/null and b/tests/ops/data/img1.png differ
diff --git a/tests/ops/data/img2.jpg b/tests/ops/data/img2.jpg
new file mode 100644
index 000000000..8595513ad
Binary files /dev/null and b/tests/ops/data/img2.jpg differ
diff --git a/tests/ops/data/img3.jpg b/tests/ops/data/img3.jpg
new file mode 100644
index 000000000..e0de8b1c6
Binary files /dev/null and b/tests/ops/data/img3.jpg differ
diff --git a/tests/ops/filter/test_image_aspect_ratio_filter.py b/tests/ops/filter/test_image_aspect_ratio_filter.py
new file mode 100644
index 000000000..3d5ea6cf4
--- /dev/null
+++ b/tests/ops/filter/test_image_aspect_ratio_filter.py
@@ -0,0 +1,123 @@
+import os
+import unittest
+
+from datasets import Dataset
+
+from data_juicer.ops.filter.image_aspect_ratio_filter import \
+ ImageAspectRatioFilter
+from data_juicer.utils.constant import Fields
+
+
+class ImageAspectRatioFilterTest(unittest.TestCase):
+
+ data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
+ '..', 'data')
+ img1_path = os.path.join(data_path, 'img1.png')
+ img2_path = os.path.join(data_path, 'img2.jpg')
+ img3_path = os.path.join(data_path, 'img3.jpg')
+
+ def _run_image_aspect_ratio_filter(self,
+ dataset: Dataset, target_list,
+ op):
+ if Fields.stats not in dataset.features:
+ dataset = dataset.add_column(name=Fields.stats,
+ column=[{}] * dataset.num_rows)
+ dataset = dataset.map(op.compute_stats)
+ dataset = dataset.filter(op.process)
+ dataset = dataset.select_columns(column_names=[op.image_key])
+ res_list = dataset.to_list()
+ self.assertEqual(res_list, target_list)
+
+ def test_filter1(self):
+
+ ds_list = [{
+ 'images': [self.img1_path]
+ }, {
+ 'images': [self.img2_path]
+ }, {
+ 'images': [self.img3_path]
+ }]
+ tgt_list = [{
+ 'images': [self.img1_path]
+ }]
+ dataset = Dataset.from_list(ds_list)
+ op = ImageAspectRatioFilter(min_ratio=0.8, max_ratio=1.2)
+ self._run_image_aspect_ratio_filter(dataset, tgt_list, op)
+
+ def test_filter2(self):
+
+ ds_list = [{
+ 'images': [self.img1_path]
+ }, {
+ 'images': [self.img2_path]
+ }, {
+ 'images': [self.img3_path]
+ }]
+ tgt_list = [{
+ 'images': [self.img1_path]
+ }, {
+ 'images': [self.img2_path]
+ }]
+ dataset = Dataset.from_list(ds_list)
+ op = ImageAspectRatioFilter(min_ratio=0.8)
+ self._run_image_aspect_ratio_filter(dataset, tgt_list, op)
+
+ def test_filter3(self):
+
+ ds_list = [{
+ 'images': [self.img1_path]
+ }, {
+ 'images': [self.img2_path]
+ }, {
+ 'images': [self.img3_path]
+ }]
+ tgt_list = [{
+ 'images': [self.img1_path]
+ }, {
+ 'images': [self.img3_path]
+ }]
+ dataset = Dataset.from_list(ds_list)
+ op = ImageAspectRatioFilter(max_ratio=1.2)
+ self._run_image_aspect_ratio_filter(dataset, tgt_list, op)
+
+ def test_any(self):
+
+ ds_list = [{
+ 'images': [self.img1_path, self.img2_path]
+ }, {
+ 'images': [self.img2_path, self.img3_path]
+ }, {
+ 'images': [self.img1_path, self.img3_path]
+ }]
+ tgt_list = [{
+ 'images': [self.img1_path, self.img2_path]
+ }, {
+ 'images': [self.img1_path, self.img3_path]
+ }]
+ dataset = Dataset.from_list(ds_list)
+ op = ImageAspectRatioFilter(min_ratio=0.8,
+ max_ratio=1.2,
+ any_or_all='any')
+ self._run_image_aspect_ratio_filter(dataset, tgt_list, op)
+
+ def test_all(self):
+
+ ds_list = [{
+ 'images': [self.img1_path, self.img2_path]
+ }, {
+ 'images': [self.img2_path, self.img3_path]
+ }, {
+ 'images': [self.img1_path, self.img3_path]
+ }]
+ tgt_list = []
+ dataset = Dataset.from_list(ds_list)
+ op = ImageAspectRatioFilter(min_ratio=0.8,
+ max_ratio=1.2,
+ any_or_all='all')
+ self._run_image_aspect_ratio_filter(dataset, tgt_list, op)
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tools/multimodal/README.md b/tools/multimodal/README.md
new file mode 100644
index 000000000..b9175c27c
--- /dev/null
+++ b/tools/multimodal/README.md
@@ -0,0 +1,93 @@
+# Multimodal Tools
+
+This folder contains some scripts and tools for processing multimodal datasets before and after using Data-Juicer.
+
+## Dataset Format Conversion
+
+Due to the large format diversity among different multimodal datasets and
+works, Data-Juicer proposes a novel intermediate format for multimodal
+datasets and provides several dataset format conversion tools for some
+popular multimodal works.
+
+These tools consist of two types:
+- Other format to Data-Juicer format: These tools are in the `source_format_to_data_juicer_format` directory. They help to convert datasets in other formats to datasets in Data-Juicer format.
+- Data-Juicer format to other format: These tools are in the `data_juicer_format_to_target_format` directory. They help to convert datasets in Data-Juicer format to datasets in the target format.
+
+For now, the dataset formats supported by Data-Juicer are listed in the following table.
+
+| Format | source_format_to_data_juicer_format | data_juicer_format_to_target_format | Ref. |
+|------------|-------------------------------------|-------------------------------------|------------------------------------------------------------------------------------------------------------------|
+| LLaVA-like | `llava_to_dj.py` | `dj_to_llava.py` | [Format Description](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) |
+
+For any of these tools, you can run the following command to see its usage:
+
+```shell
+# e.g. llava_to_dj.py
+python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help
+```
+
+Before using these tools, you might need to take a glance at the reference
+materials in the table above for each format, to better understand the
+format details and the arguments of each tool.
+
+### Notice
+There might be some tiny differences after converting a source dataset to
+Data-Juicer format and converting it back. However, these differences have
+nearly no effect on the semantics of the datasets. Here we show these tiny
+differences in detail for each source format.
+
+#### LLaVA-like
+The format of LLaVA-like datasets is defined [here](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format).
+Although it's simple, in real scenarios there might be some slight variations
+in some samples.
+
+Here we take the [visual instruction tuning dataset](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json) as an example,
+and show how these variations influence the dataset format. The table below
+shows the number of samples that differ between the original dataset and the
+dataset after each processing step. There are 665,298 samples in the original dataset.
+
+| process                                                                                  | # of differing samples |
+|----------------------------------------------------------------------------------------|-------------|
+| 1. apply `llava_to_dj.py` and `dj_to_llava.py` | 113,501 |
+| 2. convert integer ids to string ids in the original dataset | 41,361 |
+| 3. strip whitespaces before and after values of conversations in the original dataset | 40,688 |
+| 4. add `'model': ''` fields in the converted dataset | 1 |
+
+It's worth noting that processes 2-4 don't influence the semantics of the sample conversations in the dataset.
+Thus we think the dataset after conversion aligns with the original dataset.
+
+Finally, only 1 sample remains different, because its conversations contain some extra unused fields ("text", "markdown"),
+as shown below. But the "from" and "value" fields are the same between the original
+and converted datasets, so we can regard this sample as aligned with the original one as well.
+
+```json
+# original conversations
+[
+ {
+ "from":"human",
+ "value":"sentry self hosted api 504"
+ },
+ {
+ "from":"gpt",
+ "text":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\nSlow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n\nServer overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n\nLarge dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n\nAPI endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\nCheck your network connection and ensure that it is stable.\n\nVerify that the server hosting the API is not overloaded with requests.\n\nOptimize the query to retrieve only the necessary data.\n\nTry using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance.",
+ "value":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n2. Verify that the server hosting the API is not overloaded with requests.\n3. Optimize the query to retrieve only the necessary data.\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance.",
+ "markdown":{
+ "type":"answer-markdown",
+ "index":1,
+ "answer":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n\n2. Verify that the server hosting the API is not overloaded with requests.\n\n3. Optimize the query to retrieve only the necessary data.\n\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance."
+ }
+ }
+]
+
+# converted conversations
+[
+ {
+ "from":"human",
+ "value":"sentry self hosted api 504"
+ },
+ {
+ "from":"gpt",
+ "value":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n2. Verify that the server hosting the API is not overloaded with requests.\n3. Optimize the query to retrieve only the necessary data.\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance."
+ }
+]
+```
diff --git a/tools/multimodal/README_ZH.md b/tools/multimodal/README_ZH.md
new file mode 100644
index 000000000..9eb7757ce
--- /dev/null
+++ b/tools/multimodal/README_ZH.md
@@ -0,0 +1,76 @@
+# 多模态工具
+
+这个文件夹包含了一些在使用 Data-Juicer 之前和之后可以用上的多模态数据集处理脚本和工具。
+
+## 数据集格式转换
+
+由于不同多模态数据集和工作之间的数据集格式差异较大,Data-Juicer 提出了一种新颖的多模态数据集中间格式,并为一些流行的多模态工作提供了若干数据集格式转换工具。
+
+这些工具分为两种类型:
+- 其他格式到 Data-Juicer 格式的转换:这些工具在 `source_format_to_data_juicer_format` 目录中。它们可以帮助将其他格式的数据集转换为 Data-Juicer 格式的目标数据集。
+- Data-Juicer 格式到其他格式的转换:这些工具在 `data_juicer_format_to_target_format` 目录中。它们可以帮助将 Data-Juicer 格式的数据集转换为目标格式的数据集。
+
+目前,Data-Juicer 支持的数据集格式在下面表格中列出。
+
+| 格式 | source_format_to_data_juicer_format | data_juicer_format_to_target_format | 格式参考 |
+|-----------|-------------------------------------|-------------------------------------|----------------------------------------------------------------------------------------------------|
+| 类LLaVA格式 | `llava_to_dj.py` | `dj_to_llava.py` | [格式描述](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) |
+
+对于所有工具,您可以运行以下命令来了解它们的详细用法:
+
+```shell
+# 例如:llava_to_dj.py
+python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py --help
+```
+在使用这些工具之前,您可能需要查看上表中每个格式的参考资料,以更好地了解详细的格式信息,并理解每个工具的参数含义。
+
+### 注意事项
+将源数据集转换为 Data-Juicer 格式并再次转换回来后,可能会有一些微小的差异。然而,这些差异几乎不会影响数据集的语义信息。下面我们将详细展示每个支持的源格式中可能存在的这些微小差异。
+
+#### 类LLaVA格式
+类LLaVA格式数据集的格式在 [这里](https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md#dataset-format) 有具体的定义。尽管它很简单,但在实际场景中,某些样本可能会出现轻微的变体。
+
+这里我们以LLaVA的 [视觉指令微调数据集](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json) 为例,展示这些变化如何影响数据集的格式。下表显示了原始数据集和经过若干处理后数据集之间不同样本的数量。原始数据集中有665,298个样本。
+
+| 处理过程 | 不同样本数目 |
+|-------------------------------------------|---------|
+| 1. 运行 `llava_to_dj.py` 和 `dj_to_llava.py` | 113,501 |
+| 2. 将源数据集的id字段由整型转为字符串类型 | 41,361 |
+| 3. 将源数据集中对话的所有value字段前后的空格去除 | 40,688 |
+| 4. 在转换后的数据集样本中添加 `'model': ''` 字段 | 1 |
+
+值得注意的是,处理过程 2-4 并不会影响数据集中样本对话的语义,因此我们可以认为数据集格式转换工具的转换结果能够对齐源数据集。
+
+最后,只有1个样本不同。如下所示,原因为其对话中包含一些额外的无用字段("text","markdown")。但是,原始数据集和转换后的数据集之间的"from"和"value"字段是相同的,因此可以认为这个样本也是对齐的。
+
+```json
+# 原始对话
+[
+ {
+ "from":"human",
+ "value":"sentry self hosted api 504"
+ },
+ {
+ "from":"gpt",
+ "text":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\nSlow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n\nServer overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n\nLarge dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n\nAPI endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\nCheck your network connection and ensure that it is stable.\n\nVerify that the server hosting the API is not overloaded with requests.\n\nOptimize the query to retrieve only the necessary data.\n\nTry using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance.",
+ "value":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n2. Verify that the server hosting the API is not overloaded with requests.\n3. Optimize the query to retrieve only the necessary data.\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance.",
+ "markdown":{
+ "type":"answer-markdown",
+ "index":1,
+ "answer":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n\n2. Verify that the server hosting the API is not overloaded with requests.\n\n3. Optimize the query to retrieve only the necessary data.\n\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance."
+ }
+ }
+]
+
+# 转换后的对话
+[
+ {
+ "from":"human",
+ "value":"sentry self hosted api 504"
+ },
+ {
+ "from":"gpt",
+ "value":"A 504 error when using the Sentry self-hosted API typically indicates that the API request timed out. This means that the server hosting the API did not receive a response from the client within the expected time frame.\n\nHere are a few possible reasons for a 504 error when using the Sentry self-hosted API:\n\n1. Slow network connection: The API request may be taking longer than expected due to a slow network connection. Check your network connection and try again.\n2. Server overload: The server hosting the API may be overloaded with requests. Try again later when the server load has decreased.\n3. Large dataset: If you are trying to retrieve a large amount of data, the API request may take longer than expected. Try breaking up the request into smaller chunks or optimizing the query to retrieve only the necessary data.\n4. API endpoint issues: There may be issues with the specific API endpoint you are using. Try using a different endpoint or contact the Sentry support team for assistance.\n\nTo resolve a 504 error when using the Sentry self-hosted API, try the following:\n\n1. Check your network connection and ensure that it is stable.\n2. Verify that the server hosting the API is not overloaded with requests.\n3. Optimize the query to retrieve only the necessary data.\n4. Try using a different API endpoint.\n\nIf the issue persists, contact the Sentry support team for assistance."
+ }
+]
+```
diff --git a/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py b/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py
new file mode 100644
index 000000000..9df971e80
--- /dev/null
+++ b/tools/multimodal/data_juicer_format_to_target_format/dj_to_llava.py
@@ -0,0 +1,240 @@
+# This tool is used to convert a multimodal dataset in Data-Juicer format to a
+# target dataset in LLaVA format.
+#
+# Corresponding Data-Juicer format:
+# - multi-chunk interleaved image-text sequence
+# - in jsonl
+# {'id': '000000033471',
+# 'images': ['coco/train2017/000000033471.jpg'],
+#  'text': '[[human]]: <__dj__image>\n'
+# 'What are the colors of the bus in the image?\n'
+# '[[gpt]]: The bus in the image is white and red.\n'
+# '[[human]]: What feature can be seen on the back of the bus?\n'
+# '[[gpt]]: The back of the bus features an advertisement.\n'
+# '[[human]]: Is the bus driving down the street or pulled off to'
+# 'the side?\n'
+# '[[gpt]]: The bus is driving down the street, which is crowded '
+# 'with people and other vehicles. <|__dj__eoc|>'}
+#
+# LLaVA format:
+# - single/multi-turn conversation
+# - in json
+# [
+# {
+# "id": "000000033471",
+# "image": "coco/train2017/000000033471.jpg",
+# "conversations": [
+# {
+# "from": "human",
+# "value": "\nWhat are the colors of the bus in the image?"
+# },
+# {
+# "from": "gpt",
+# "value": "The bus in the image is white and red."
+# },
+# {
+# "from": "human",
+# "value": "What feature can be seen on the back of the bus?"
+# },
+# {
+# "from": "gpt",
+# "value": "The back of the bus features an advertisement."
+# },
+# {
+# "from": "human",
+# "value": "Is the bus driving down the street or pulled off to the side?"
+# },
+# {
+# "from": "gpt",
+# "value": "The bus is driving down the street, which is crowded with people and other vehicles."
+# }
+# ]
+# },
+# ...
+# ]
+#
+# Reference:
+# https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md
+
+import json
+import os
+
+import fire
+import jsonlines as jl
+import regex as re
+from loguru import logger
+from tqdm import tqdm
+
+from data_juicer.utils.mm_utils import SpecialTokens
+
+@logger.catch
+def main(
+ dj_ds_path: str,
+ target_llava_ds_path: str,
+ keep_only_first_image: bool = True,
+ eoc_special_token: str = SpecialTokens.eoc,
+    image_special_token: str = '<image>',
+ sent_seperator: str = '\n',
+ convert_to_relative_paths: bool = False,
+ original_llava_ds_path: str = None,
+):
+ """
+ Convert a Data-Juicer-format dataset to a LLaVA-like dataset.
+
+ :param dj_ds_path: path to the input dataset in Data-Juicer format.
+ :param target_llava_ds_path: path to store the converted dataset in LLaVA
+ format.
+ :param keep_only_first_image: whether to only keep the image token in the
+ first conversation round. Default: True.
+ :param eoc_special_token: the special token for "end of a chunk". It's used
+ to split conversation chunks explicitly. Default: <|__dj__eoc|> (from
+ Data-Juicer).
+    :param image_special_token: the special token for images. It's used to
+        locate the images in the conversation. In typical LLaVA-like datasets,
+        this token is always "<image>". You can change it to align with your
+        own LLaVA-like datasets, but should be careful of possible
+        compatibility problems that come from this change. Default: <image>.
+    :param sent_seperator: the separator used to split different sentences.
+        Default: \n.
+    :param convert_to_relative_paths: whether to convert the image paths in
+        this dataset to paths relative to the original dataset. If it's True,
+        the extra argument original_llava_ds_path is required. When the
+        processed and converted dataset will be used on another machine, it's
+        better to set this argument to True. Default: False.
+ :param original_llava_ds_path: path to the original unprocessed llava
+ dataset, which is used to help to recover the relative image paths for
+ better migration. Default: None.
+ """
+ # ----- Constant settings. Better not to change them. -----
+ # default key of field to store the sample text
+ text_key = 'text'
+ # default key of field to store the image list
+ image_key = 'images'
+ # default pattern for the conversation role
+ from_pattern = re.compile(r'\[\[([a-zA-Z]*?)\]\]: ')
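+    # e.g. (an illustration of the split behavior with the default pattern):
+    #   from_pattern.split('[[human]]: hi\n[[gpt]]: hello')
+    #   -> ['', 'human', 'hi\n', 'gpt', 'hello']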
+ # ----- Constant settings. Better not to change them. -----
+ # check arguments
+ # check paths
+ if not os.path.exists(dj_ds_path):
+ raise FileNotFoundError(
+ f'Input dataset [{dj_ds_path}] can not be found.')
+    if not target_llava_ds_path.endswith('.json'):
+        raise ValueError(
+            'Only "json" target dataset files are supported for LLaVA now.')
+ if os.path.dirname(target_llava_ds_path) \
+ and not os.path.exists(os.path.dirname(target_llava_ds_path)):
+ logger.info(
+ f'Create directory [{os.path.dirname(target_llava_ds_path)}] for '
+ f'the target dataset.')
+ os.makedirs(os.path.dirname(target_llava_ds_path))
+
+    # check if the default image special token is changed
+    if image_special_token != '<image>':
+        logger.warning('The image_special_token used in the original LLaVA '
+                       'dataset is "<image>". It\'s better to align with '
+                       'this token. There might be some compatibility '
+                       'problems if you change it.')
+
+ # if convert_to_relative_paths is True, check if the original_llava_ds_path
+ # is provided as well.
+ if convert_to_relative_paths:
+ if not original_llava_ds_path:
+ raise ValueError(f'When convert_to_relative_paths is set to True, '
+ f'the original_llava_ds_path must be provided '
+ f'for recovering the relative paths. Please '
+ f'check and retry.')
+ original_llava_ds_path = os.path.abspath(original_llava_ds_path)
+ # if provided original_llava_ds_path is the dataset file path, only
+ # keep the directory path.
+ if os.path.isfile(original_llava_ds_path):
+ original_llava_ds_path = os.path.dirname(original_llava_ds_path)
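+        # e.g. (an illustration): with original_llava_ds_path '/data/llava'
+        # and an absolute image path '/data/llava/coco/train2017/x.jpg', the
+        # converted sample stores the relative path 'coco/train2017/x.jpg'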
+
+ logger.info('Start to convert.')
+ samples = []
+ with jl.open(dj_ds_path, 'r') as reader:
+ for sample in tqdm(reader):
+ id = sample['id']
+ images = list(set(sample.get(image_key, [])))
+ text = sample[text_key]
+
+ if len(images) > 1:
+                raise ValueError(f'There is more than one distinct image in '
+                                 f'the sample with id [{id}], which is not '
+                                 f'compatible with the LLaVA dataset format. '
+                                 f'Please check and fix it and retry.')
+
+ # convert dj text format to LLaVA conversation format
+ # split the text into a list of:
+ # [role1, sent1, role2, sent2, role1, sent3, role2, sent4, ...]
+ parts = from_pattern.split(text)
+ if parts[0] == '':
+ parts = parts[1:]
+            if len(parts) % 4 != 0:
+                raise ValueError(f'The conversations in the sample text with '
+                                 f'id [{id}] contain an unbalanced (human, '
+                                 f'robot) conversation round (the number of '
+                                 f'split parts is [{len(parts)}]). Please '
+                                 f'check and fix the dataset and retry.')
+
+ conversations = []
+ # the number of sentences
+ num_sent = len(parts) // 2
+ for i in range(num_sent):
+ # get role and its sentence
+ role = parts[2 * i]
+ sent = parts[2 * i + 1].strip()
+
+            # remove the trailing sentence separator
+ if sent.endswith(sent_seperator):
+ sent = sent[:-len(sent_seperator)].strip()
+ # remove possible eoc_special_tokens
+ if sent.endswith(eoc_special_token):
+ sent = sent[:-len(eoc_special_token)].strip()
+ # remove possible image_special_tokens when only keeping it in
+ # the first conversation round
+ if i > 0 and keep_only_first_image:
+ if sent.startswith(image_special_token):
+ sent = sent[len(image_special_token):].strip()
+ if sent.startswith(sent_seperator):
+ sent = sent[len(sent_seperator):].strip()
+ if sent.endswith(image_special_token):
+ sent = sent[:-len(image_special_token)].strip()
+ if sent.endswith(sent_seperator):
+ sent = sent[:-len(sent_seperator)].strip()
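+                # e.g. (an illustration with the default tokens): a
+                # later-round sentence '<image>\nWhat feature ...' becomes
+                # 'What feature ...' when keep_only_first_image is True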
+
+ conversation = {
+ 'from': role,
+ 'value': sent
+ }
+ conversations.append(conversation)
+
+ # make up the new sample
+ new_sample = {
+ 'id': id,
+ 'conversations': conversations
+ }
+ if len(images) == 1:
+ image_path = images[0]
+ if convert_to_relative_paths:
+ if image_path.startswith(original_llava_ds_path):
+ image_path = os.path.relpath(image_path,
+ original_llava_ds_path)
+ else:
+ raise ValueError(f'The original_llava_ds_path '
+ f'[{original_llava_ds_path}] is not '
+ f'the directory that contains the '
+ f'image [{image_path}] in the sample '
+ f'with id [{id}]. Please check if '
+ f'the correct original_llava_ds_path '
+ f'is provided or something wrong '
+ f'with this sample, and try again '
+ f'later.')
+ new_sample['image'] = image_path
+ samples.append(new_sample)
+
+    logger.info(f'Start to write the converted dataset to '
+                f'[{target_llava_ds_path}]...')
+    with open(target_llava_ds_path, 'w', encoding='utf-8') as fout:
+        json.dump(samples, fout)
+
+
+if __name__ == '__main__':
+ fire.Fire(main)
diff --git a/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py b/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py
new file mode 100644
index 000000000..4cff127e7
--- /dev/null
+++ b/tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py
@@ -0,0 +1,275 @@
+# This tool is used to convert a multimodal dataset in LLaVA format to a
+# target dataset in Data-Juicer format.
+#
+# LLaVA format:
+# - single/multi-turn conversation
+# - in json
+# [
+# {
+# "id": "000000033471",
+# "image": "coco/train2017/000000033471.jpg",
+# "conversations": [
+# {
+# "from": "human",
+# "value": "\nWhat are the colors of the bus in the image?"
+# },
+# {
+# "from": "gpt",
+# "value": "The bus in the image is white and red."
+# },
+# {
+# "from": "human",
+# "value": "What feature can be seen on the back of the bus?"
+# },
+# {
+# "from": "gpt",
+# "value": "The back of the bus features an advertisement."
+# },
+# {
+# "from": "human",
+# "value": "Is the bus driving down the street or pulled off to the side?"
+# },
+# {
+# "from": "gpt",
+# "value": "The bus is driving down the street, which is crowded with people and other vehicles."
+# }
+# ]
+# },
+# ...
+# ]
+#
+# Corresponding Data-Juicer format:
+# - multi-chunk interleaved image-text sequence
+# - in jsonl
+# {'id': '000000033471',
+# 'images': ['coco/train2017/000000033471.jpg'],
+# 'text': '[[human]]: <image>\n'
+# 'What are the colors of the bus in the image?\n'
+# '[[gpt]]: The bus in the image is white and red.\n'
+# '[[human]]: What feature can be seen on the back of the bus?\n'
+# '[[gpt]]: The back of the bus features an advertisement.\n'
+# '[[human]]: Is the bus driving down the street or pulled off to '
+# 'the side?\n'
+# '[[gpt]]: The bus is driving down the street, which is crowded '
+# 'with people and other vehicles. <|__dj__eoc|>'}
+#
+# Reference:
+# https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md
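+#
+# Example usage (a hypothetical invocation; the paths below are placeholders):
+#   python tools/multimodal/source_format_to_data_juicer_format/llava_to_dj.py \
+#       --llava_ds_path datasets/llava_format.json \
+#       --target_ds_path datasets/dj_format.jsonl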
+
+import json
+import os
+import random
+
+import fire
+import jsonlines as jl
+from loguru import logger
+from tqdm import tqdm
+
+from data_juicer.utils.mm_utils import SpecialTokens
+
+
+@logger.catch
+def main(
+ llava_ds_path: str,
+ target_ds_path: str,
+ str_id: bool = True,
+ split_chunk: bool = False,
+ image_broadcast: bool = False,
+ image_broadcast_pos: str = 'random',
+ eoc_special_token: str = SpecialTokens.eoc,
+    image_special_token: str = '<image>',
+ add_eoc_at_last: bool = True,
+ sent_seperator: str = '\n',
+):
+ """
+ Convert a LLaVA-like dataset to the Data-Juicer format.
+
+ :param llava_ds_path: path to the input LLaVA-like dataset.
+ :param target_ds_path: path to store the converted dataset in Data-Juicer
+ format.
+ :param str_id: whether to convert all ids to str type. Default: True.
+    :param split_chunk: whether to split each (human, robot) conversation
+        round into a separate chunk. Default: False.
+ :param image_broadcast: whether to broadcast the image token to all
+ conversation rounds. If it's True, an image_special_token will be added
+ to the human question in each conversation round. Default: False.
+    :param image_broadcast_pos: the position to add the broadcast
+        image_special_token. Should be one of ["before", "after", "random",
+        "follow"], which means adding this token before/after the human
+        sentence, randomly choosing "before" or "after", or following the
+        position in the first conversation round. Default: random.
+ :param eoc_special_token: the special token for "end of a chunk". It's used
+ to split conversation chunks explicitly. Default: <|__dj__eoc|> (from
+ Data-Juicer).
+    :param image_special_token: the special token for images. It's used to
+        locate the images in the conversation. In typical LLaVA-like datasets,
+        this token is always "<image>". You can change it to align with your
+        own LLaVA-like datasets, but be careful of possible compatibility
+        problems that come from this change. Default: <image>.
+ :param add_eoc_at_last: whether to add an extra eoc_special_token at the
+ end of text. Default: True.
+    :param sent_seperator: separator to split different sentences. Default: \n.
+ """
+ # ----- Constant settings. Better not to change them. -----
+ text_key = 'text' # default key of field to store the sample text
+ image_key = 'images' # default key of field to store the image list
+ from_format = '[[%s]]: ' # default handle method for the conversation role
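+    # e.g. from_format % 'human' yields '[[human]]: ', which matches the role
+    # pattern expected by dj_to_llava.py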
+ # ----- Constant settings. Better not to change them. -----
+
+ # check arguments
+ # check paths
+ if not os.path.exists(llava_ds_path):
+ raise FileNotFoundError(f'Input LLaVA dataset [{llava_ds_path}] can '
+ f'not be found.')
+    if not target_ds_path.endswith('.jsonl'):
+        raise ValueError('Only "jsonl" target dataset files are supported '
+                         'now.')
+ if os.path.dirname(target_ds_path) \
+ and not os.path.exists(os.path.dirname(target_ds_path)):
+ logger.info(f'Create directory [{os.path.dirname(target_ds_path)}] '
+ f'for the target dataset.')
+ os.makedirs(os.path.dirname(target_ds_path))
+ # check whether to split chunk and broadcast image token to each chunk
+ if image_broadcast:
+ if not split_chunk:
+ raise ValueError('Arg split_chunk should be True when opening '
+ 'image_broadcast.')
+ if image_broadcast_pos not in ['random', 'before', 'after', 'follow']:
+ raise ValueError(f'Arg image_broadcast_pos should be one of ['
+ f'"random", "before", "after", "follow"], but '
+ f'given [{image_broadcast_pos}]')
+    # check if the default image special token is changed
+    if image_special_token != '<image>':
+        logger.warning('The image_special_token used in the original LLaVA '
+                       'dataset is "<image>". It\'s better to align with '
+                       'this token. There might be some compatibility '
+                       'problems if you change it.')
+    # check whether to add the eoc special token at the end
+    if not add_eoc_at_last:
+        logger.warning('You chose not to add the special eoc token at the '
+                       'end of the text, which might cause compatibility '
+                       'problems for other types of datasets '
+                       '(e.g. OpenFlamingo).')
+
+    # load the LLaVA dataset
+    logger.info('Loading the original LLaVA dataset.')
+    with open(llava_ds_path, 'r', encoding='utf-8') as fin:
+        llava_ds = json.load(fin)
+    logger.info(f'Loaded [{len(llava_ds)}] samples.')
+
+ with jl.open(target_ds_path, 'w') as writer:
+ for sample in tqdm(llava_ds):
+ # id
+ id = sample['id']
+ if str_id:
+ id = str(id)
+
+ # images and text
+ image = sample.get('image', '')
+ if image == '':
+                logger.warning(f'There is no image in the sample with id '
+                               f'[{id}], which means this sample is not a '
+                               f'multimodal sample. You\'d better remove it '
+                               f'before converting.')
+
+ conversations = sample['conversations']
+
+            # assume the input dataset always contains multimodal
+            # conversations and that they always consist of (human, robot)
+            # pairs
+            if len(conversations) % 2 != 0:
+                raise ValueError(f'The conversations in the sample with id '
+                                 f'[{id}] contain an unbalanced (human, '
+                                 f'robot) conversation round (the number of '
+                                 f'conversations is [{len(conversations)}]). '
+                                 f'Please check and fix the dataset and '
+                                 f'retry.')
+
+ # image list
+ images = []
+ # record the image token position in the first conversation round
+ image_token_pos_in_first_round = ''
+ # save the formatted conversations
+ formatted_conversations = []
+ # the number of conversation rounds
+ num_round = len(conversations) // 2
+ for i in range(num_round):
+ # get the human question and robot answer in this round
+ human_round = conversations[2 * i]
+ robot_round = conversations[2 * i + 1]
+
+ # get the role and sentence values
+ role_human = from_format % human_round['from']
+ sent_human = human_round['value']
+ role_robot = from_format % robot_round['from']
+ sent_robot = robot_round['value']
+
+ if image == '':
+                    # not a multimodal sample, keep everything as-is
+ pass
+ elif i == 0:
+ # record the image token position in the first round
+ if sent_human.startswith(image_special_token):
+ image_token_pos_in_first_round = 'before'
+ elif sent_human.endswith(image_special_token):
+ image_token_pos_in_first_round = 'after'
+ else:
+ raise ValueError(
+ f'The position of image_special_token in the '
+ f'first round conversation of sample with id '
+ f'[{id}] is neither before nor after the text. '
+ f'The position might be wrong or there is no '
+ f'image_special_token in this sample. Please '
+ f'check and fix the dataset and retry.'
+ )
+ images.append(image)
+ else:
+ # whether broadcast image special token to following
+ # conversation rounds
+ if image_broadcast:
+ # broadcast image to each conversation round
+ if image_broadcast_pos == 'before':
+ sent_human = image_special_token + sent_seperator \
+ + sent_human
+ elif image_broadcast_pos == 'after':
+ sent_human += sent_seperator + image_special_token
+ elif image_broadcast_pos == 'random':
+ if random.random() < 0.5:
+ # before
+ sent_human = image_special_token \
+ + sent_seperator + sent_human
+ else:
+ # after
+ sent_human += sent_seperator \
+ + image_special_token
+ else:
+ # follow the first round conversation
+ if image_token_pos_in_first_round == 'before':
+ sent_human = image_special_token \
+ + sent_seperator + sent_human
+ else:
+ sent_human += sent_seperator \
+ + image_special_token
+ images.append(image)
+
+ # combine these texts together
+ new_sent = role_human + sent_human + sent_seperator \
+ + role_robot + sent_robot
+ formatted_conversations.append(new_sent)
+
+ join_sep = sent_seperator
+ if split_chunk:
+ # split (human, robot) pairs into several chunks
+ join_sep = f' {eoc_special_token} ' + join_sep
+ text = join_sep.join(formatted_conversations)
+ if add_eoc_at_last:
+ # add an extra eoc token after the whole sample text
+ text += f' {eoc_special_token}'
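+            # e.g. (an illustration with two formatted rounds r1 and r2 and
+            # default tokens): split_chunk=False with add_eoc_at_last=True
+            # yields 'r1\nr2 <|__dj__eoc|>', while split_chunk=True yields
+            # 'r1 <|__dj__eoc|> \nr2 <|__dj__eoc|>'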
+
+ # get the new sample with Data-Juicer format
+ new_sample = {
+ 'id': id,
+ text_key: text,
+ image_key: images,
+ }
+ writer.write(new_sample)
+    logger.info(f'Stored the target dataset into [{target_ds_path}].')
+
+
+if __name__ == '__main__':
+ fire.Fire(main)