From b4cef5d47467906b0e43066502268a6b485a71ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ce=20Ge=20=28=E6=88=88=E7=AD=96=29?= Date: Wed, 17 Jul 2024 20:32:30 +0800 Subject: [PATCH] Refactor OP & Dataset (#336) * modelscope-sora news (#323) * News/modelscope sora (#327) * modelscope-sora news * remove empower * debug for gpu rank for analyser (#329) * debug for gpu rank for analyser * spec_numprocs -> num_proc * Add more unittest (#304) * add unittest env with gpu * fix unittest yml * add environment for unittest * update workflow trigger * update install step * fix install command * update working dir * update container * update working dir * change working directory * change working directory * change working directory * change working directory * change unittest * use test tag * finish tag support * support run op with different executor * fix pre-commit * add hf mirror * add hf mirror * run all test in standalone mode by default * ignore image face ratio * update tags * add ray testcase * add ray test in workflow * update ray unittest workflow * delete old unittest --------- Co-authored-by: root * Add source tag (#317) * add source tag for some mapper op * fix no attribute 'current_tag' when executing local tests * move op process logic from executor to base op * fix typo * move export outside op * init refactor * update analyser * fix format * clean up * bring back batch mapper * Improve fault tolerance & Fix Ray executor * fix wrapper * fix batched filter * Remove use_actor as it is not compatible with the refactored OP class, unless the dataset class is refactored * make wrappers work with unittests * Compatible with unit tests and works with ray * fix unittest * fix wrappers with ray, map, filter * unify unittests * wrap deduplicators * Compatible with non-batched calls * Class-level wrappers - compatible with dataset.filter - bring back nested wrappers * Instance-level wrappers * Refined instance-level wrappers - Remove incomplete dataset.filter wrappers - Simplify code - Stack wrappers * fix use_cuda * Refactor dataset (#348) * refactor dataset * update unittest with DJDataset * fix unittest * update ray data load * add test * ray read json * update docker image version * actor is no longer supported * Regress filter's stats export logic --------- Co-authored-by: BeachWang <1400012807@pku.edu.cn> Co-authored-by: Xuchen Pan <32844285+pan-x-c@users.noreply.github.com> Co-authored-by: chenhesen Co-authored-by: garyzhang99 --- .github/workflows/docker/docker-compose.yml | 65 +++++ .github/workflows/unit-test.yml | 87 +++--- README.md | 1 + README_ZH.md | 1 + data_juicer/core/analyser.py | 21 +- data_juicer/core/data.py | 68 ++++- data_juicer/core/executor.py | 109 +------- data_juicer/core/ray_data.py | 132 +++++++++ data_juicer/core/ray_executor.py | 150 +---------- data_juicer/ops/base_op.py | 251 +++++++++++++++--- data_juicer/ops/load.py | 3 + .../ops/mapper/audio_ffmpeg_wrapped_mapper.py | 11 + data_juicer/ops/mapper/image_blur_mapper.py | 25 +- .../image_captioning_from_gpt4v_mapper.py | 4 +- .../ops/mapper/image_captioning_mapper.py | 5 +- .../ops/mapper/image_diffusion_mapper.py | 4 +- .../ops/mapper/image_face_blur_mapper.py | 10 + data_juicer/ops/mapper/nlpaug_en_mapper.py | 3 +- data_juicer/ops/mapper/nlpcda_zh_mapper.py | 3 +- .../video_captioning_from_audio_mapper.py | 4 +- .../video_captioning_from_frames_mapper.py | 3 +- ...video_captioning_from_summarizer_mapper.py | 4 +- .../video_captioning_from_video_mapper.py | 3 +- .../ops/mapper/video_face_blur_mapper.py | 11 + 
.../ops/mapper/video_ffmpeg_wrapped_mapper.py | 11 + .../mapper/video_remove_watermark_mapper.py | 11 + .../video_resize_aspect_ratio_mapper.py | 11 + .../mapper/video_resize_resolution_mapper.py | 11 + .../mapper/video_split_by_duration_mapper.py | 13 +- .../mapper/video_split_by_key_frame_mapper.py | 12 +- .../ops/mapper/video_split_by_scene_mapper.py | 8 + data_juicer/utils/constant.py | 3 + data_juicer/utils/process_utils.py | 39 +-- data_juicer/utils/registry.py | 1 + data_juicer/utils/unittest_utils.py | 64 +++++ docs/DeveloperGuide.md | 5 +- docs/DeveloperGuide_ZH.md | 5 +- environments/dist_requires.txt | 2 +- tests/config/test_config_funcs.py | 24 +- tests/format/test_unify_format.py | 2 +- .../test_document_deduplicator.py | 2 +- .../test_document_minhash_deduplicator.py | 2 +- .../test_document_simhash_deduplicator.py | 2 +- .../deduplicator/test_image_deduplicator.py | 2 +- .../deduplicator/test_video_deduplicator.py | 2 +- tests/ops/filter/test_alphanumeric_filter.py | 29 +- .../ops/filter/test_audio_duration_filter.py | 48 ++-- tests/ops/filter/test_audio_nmf_snr_filter.py | 2 +- tests/ops/filter/test_audio_size_filter.py | 2 +- .../filter/test_average_line_length_filter.py | 2 +- .../test_character_repetition_filter.py | 2 +- tests/ops/filter/test_flagged_words_filter.py | 2 +- .../filter/test_image_aesthetics_filter.py | 2 +- .../filter/test_image_aspect_ratio_filter.py | 2 +- .../filter/test_image_face_ratio_filter.py | 5 +- tests/ops/filter/test_image_nsfw_filter.py | 2 +- tests/ops/filter/test_image_shape_filter.py | 2 +- tests/ops/filter/test_image_size_filter.py | 2 +- .../filter/test_image_text_matching_filter.py | 2 +- .../test_image_text_similarity_filter.py | 2 +- .../ops/filter/test_image_watermark_filter.py | 2 +- .../filter/test_language_id_score_filter.py | 2 +- .../filter/test_maximum_line_length_filter.py | 2 +- tests/ops/filter/test_perplexity_filter.py | 2 +- .../test_phrase_grounding_recall_filter.py | 2 +- .../filter/test_special_characters_filter.py | 2 +- .../ops/filter/test_specified_field_filter.py | 2 +- .../test_specified_numeric_field_filter.py | 2 +- tests/ops/filter/test_stop_words_filter.py | 2 +- tests/ops/filter/test_suffix_filter.py | 2 +- tests/ops/filter/test_text_action_filter.py | 2 +- .../test_text_entity_dependency_filter.py | 2 +- tests/ops/filter/test_text_length_filter.py | 2 +- tests/ops/filter/test_token_num_filter.py | 2 +- .../filter/test_video_aesthetics_filter.py | 2 +- .../filter/test_video_aspect_ratio_filter.py | 2 +- .../ops/filter/test_video_duration_filter.py | 2 +- ...est_video_frames_text_similarity_filter.py | 2 +- .../filter/test_video_motion_score_filter.py | 2 +- tests/ops/filter/test_video_nsfw_filter.py | 2 +- .../test_video_ocr_area_ratio_filter.py | 2 +- .../filter/test_video_resolution_filter.py | 2 +- .../test_video_tagging_from_frames_filter.py | 4 +- .../ops/filter/test_video_watermark_filter.py | 2 +- tests/ops/filter/test_word_num_filter.py | 2 +- .../ops/filter/test_word_repetition_filter.py | 2 +- .../test_audio_ffmpeg_wrapped_mapper.py | 2 +- tests/ops/mapper/test_image_blur_mapper.py | 2 +- .../mapper/test_image_captioning_mapper.py | 28 +- .../ops/mapper/test_image_diffusion_mapper.py | 16 +- .../ops/mapper/test_image_face_blur_mapper.py | 2 +- tests/ops/mapper/test_nlpaug_en_mapper.py | 4 +- tests/ops/mapper/test_nlpcda_zh_mapper.py | 4 +- ...test_video_captioning_from_audio_mapper.py | 16 +- ...video_captioning_from_summarizer_mapper.py | 6 +- .../ops/mapper/test_video_face_blur_mapper.py | 2 +- 
.../test_video_ffmpeg_wrapped_mapper.py | 2 +- .../test_video_remove_watermark_mapper.py | 2 +- .../test_video_resize_aspect_ratio_mapper.py | 2 +- .../test_video_resize_resolution_mapper.py | 2 +- .../test_video_split_by_duration_mapper.py | 9 +- .../test_video_split_by_key_frame_mapper.py | 55 +++- .../test_video_split_by_scene_mapper.py | 2 +- .../test_video_tagging_from_audio_mapper.py | 4 +- .../test_video_tagging_from_frames_mapper.py | 4 +- ...test_frequency_specified_field_selector.py | 2 +- .../test_topk_specified_field_selector.py | 2 +- tests/run.py | 66 ++--- 108 files changed, 1012 insertions(+), 592 deletions(-) create mode 100644 .github/workflows/docker/docker-compose.yml create mode 100644 data_juicer/core/ray_data.py diff --git a/.github/workflows/docker/docker-compose.yml b/.github/workflows/docker/docker-compose.yml new file mode 100644 index 000000000..eeba32206 --- /dev/null +++ b/.github/workflows/docker/docker-compose.yml @@ -0,0 +1,65 @@ +version: '3' +services: + ray-head: + image: data-juicer-unittest:0.2.1 + pull_policy: never + command: ray start --head --dashboard-host 0.0.0.0 --include-dashboard true --block + environment: + - HF_HOME=/data/huggingface + - HF_ENDPOINT=https://hf-mirror.com + - TORCH_HOME=/data/torch + - NLTK_DATA=/data/nltk + - DATA_JUICER_CACHE_HOME=/data/dj + - RAY_ADDRESS=auto + working_dir: /workspace + networks: + - ray-network + volumes: + - huggingface_cache:/data + - ../../..:/workspace + ports: + - "6379:6379" + - "8265:8265" + shm_size: "64G" + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['0', '1'] + capabilities: [gpu] + + ray-worker: + image: data-juicer-unittest:0.2.1 + pull_policy: never + command: ray start --address=ray-head:6379 --block + environment: + - HF_HOME=/data/huggingface + - HF_ENDPOINT=https://hf-mirror.com + - TORCH_HOME=/data/torch + - NLTK_DATA=/data/nltk + - DATA_JUICER_CACHE_HOME=/data/dj + working_dir: /workspace + volumes: + - huggingface_cache:/data + - ../../..:/workspace + depends_on: + - ray-head + networks: + - ray-network + shm_size: "64G" + deploy: + resources: + reservations: + devices: + - driver: nvidia + device_ids: ['2', '3'] + capabilities: [gpu] + +networks: + ray-network: + driver: bridge + +volumes: + huggingface_cache: + external: true diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 4962d78dd..f292f6fdd 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -1,58 +1,63 @@ # This workflow will install Python dependencies, run tests and lint with a single version of Python # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python -name: Unit Test +name: unittest -on: [push, pull_request, workflow_dispatch] +on: + workflow_dispatch: + pull_request: + push: + branches: + - main permissions: contents: read -jobs: - build: - - runs-on: ubuntu-latest +env: + ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true +jobs: + unittest-single: + runs-on: [self-hosted, linux] + environment: Testing steps: - uses: actions/checkout@v3 - - name: Check disk space - run: | - df -h - - name: Set up Python 3.8 - uses: actions/setup-python@v3 with: - python-version: "3.8" - - name: Check disk space + path: dj-${{ github.run_id }} + + - name: Setup docker compose + working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - df -h - - name: Install dependencies + docker compose up -d + + - name: Install data-juicer + working-directory: dj-${{ 
github.run_id }}/.github/workflows/docker run: | - sudo apt-get install ffmpeg - python -m pip install --upgrade pip - pip install -v -e .[all] - pip install -v -e .[sandbox] - - name: Increase swapfile + docker compose exec ray-head pip install -e .\[all\] + docker compose exec ray-worker pip install -e .\[all\] + + - name: Clean dataset cache + working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - df -h - free -h - sudo swapoff -a - sudo fallocate -l 12G /mnt/swapfile - sudo chmod 600 /mnt/swapfile - sudo mkswap /mnt/swapfile - sudo swapon /mnt/swapfile - sudo swapon --show - - name: Clean data-juicer assets and models after cached - uses: webiny/action-post-run@3.1.0 - with: - run: rm -rf ~/.cache/data_juicer - - name: Cache data-juicer assets and models - uses: actions/cache@v3 - with: - path: ~/.cache/data_juicer - key: dj-assets-models - - name: Check disk space + docker compose exec ray-head rm -rf /data/huggingface/dataset + + - name: Run unittest standalone + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + run: | + docker compose exec ray-head python tests/run.py --tag standalone + + - name: Run unittest ray + working-directory: dj-${{ github.run_id }}/.github/workflows/docker run: | - df -h - - name: Run the test + docker compose exec ray-head python tests/run.py --tag ray + + - name: Remove docker compose + working-directory: dj-${{ github.run_id }}/.github/workflows/docker + if: always() + run: | + docker compose down --remove-orphans + + - name: Cleanup workspace + if: always() run: | - python tests/run.py + rm -rf dj-${{ github.run_id }} diff --git a/README.md b/README.md index 59fddbd32..96c7181e1 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ or [DingDing group](https://qr.dingtalk.com/action/joingroup?spm=a2c22.12281976. ---- ## News +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-06-01] ModelScope-Sora "Data Directors" creative sprint—Our third data-centric LLM competition has kicked off! Please visit the competition's [official website](https://tianchi.aliyun.com/competition/entrance/532219) for more information. - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-03-07] We release **Data-Juicer [v0.2.0](https://github.com/alibaba/data-juicer/releases/tag/v0.2.0)** now! In this new version, we support more features for **multimodal data (including video now)**, and introduce **[DJ-SORA](docs/DJ_SORA.md)** to provide open large-scale, high-quality datasets for SORA-like models. - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] We have actively maintained an *awesome list of LLM-Data*, welcome to [visit](docs/awesome_llm_data.md) and contribute! diff --git a/README_ZH.md b/README_ZH.md index 1a19fd5c9..8abb1fa22 100644 --- a/README_ZH.md +++ b/README_ZH.md @@ -33,6 +33,7 @@ Data-Juicer(包含[DJ-SORA](docs/DJ_SORA_ZH.md))正在积极更新和维护 ---- ## 新消息 +- ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-06-01] ModelScope-Sora“数据导演”创意竞速——第三届Data-Juicer大模型数据挑战赛已经正式启动!立即访问[竞赛官网](https://tianchi.aliyun.com/competition/entrance/532219),了解赛事详情。 - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-03-07] 我们现在发布了 **Data-Juicer [v0.2.0](https://github.com/alibaba/data-juicer/releases/tag/v0.2.0)**! 
在这个新版本中,我们支持了更多的 **多模态数据(包括视频)** 相关特性。我们还启动了 **[DJ-SORA](docs/DJ_SORA_ZH.md)** ,为SORA-like大模型构建开放的大规模高质量数据集! - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-20] 我们在积极维护一份关于LLM-Data的*精选列表*,欢迎[访问](docs/awesome_llm_data.md)并参与贡献! - ![new](https://img.alicdn.com/imgextra/i4/O1CN01kUiDtl1HVxN6G56vN_!!6000000000764-2-tps-43-19.png) [2024-02-05] 我们的论文被SIGMOD'24 industrial track接收! diff --git a/data_juicer/core/analyser.py b/data_juicer/core/analyser.py index 641471156..36e49c6a1 100644 --- a/data_juicer/core/analyser.py +++ b/data_juicer/core/analyser.py @@ -7,9 +7,7 @@ from data_juicer.format import load_formatter from data_juicer.ops import Filter, load_ops from data_juicer.utils import cache_utils -from data_juicer.utils.constant import Fields -from .data import add_same_content_to_new_column from .exporter import Exporter @@ -87,21 +85,12 @@ def run(self, load_data_np=None, skip_export=False): # 2. stats precompute only for filter ops logger.info('Computing the stats of dataset...') stats_collected = False - for op_cfg, op in zip(self.cfg.process, self.ops): - op_name = list(op_cfg.keys())[0] + for op in self.ops: if isinstance(op, Filter): - if Fields.stats not in dataset.features: - # only add stats when calling filter op - dataset = dataset.map(add_same_content_to_new_column, - fn_kwargs={ - 'new_column_name': Fields.stats, - 'initial_value': {} - }, - num_proc=self.cfg.np, - desc='Adding new column for stats') - dataset = dataset.map(op.compute_stats, - num_proc=self.cfg.np, - desc=op_name + '_compute_stats') + original_process = op.process + op.process = None + dataset = dataset.process(op) + op.process = original_process stats_collected = True if not stats_collected: logger.warning('No stats collected. Please add some Filter ops to ' diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py index 326730c6b..14f5f507e 100644 --- a/data_juicer/core/data.py +++ b/data_juicer/core/data.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import copy import inspect +from abc import ABC, abstractmethod from functools import wraps +from time import time from typing import Union from datasets import Dataset, DatasetDict, is_caching_enabled @@ -14,6 +18,21 @@ from data_juicer.utils.fingerprint_utils import generate_fingerprint +class DJDataset(ABC): + """Base dataset of DJ""" + + @abstractmethod + def process( + self, + operators, # TODO: add type hint + *, + exporter=None, + checkpointer=None, + tracer=None) -> DJDataset: + """process a list of operators on the dataset.""" + pass + + def wrap_func_with_nested_access(f): """ Before conducting actual function `f`, wrap its args and kargs into nested @@ -116,7 +135,7 @@ def map(self, **args): return super().map(**args) -class NestedDataset(Dataset): +class NestedDataset(Dataset, DJDataset): """Enhanced HuggingFace-Dataset for better usability and efficiency.""" def __init__(self, *args, **kargs): @@ -139,6 +158,37 @@ def __getitem__(self, key): res = super().__getitem__(key) return nested_obj_factory(res) + def process(self, + operator, + *, + exporter=None, + checkpointer=None, + tracer=None): + if operator is None: + return self + + if not isinstance(operator, list): + ops = [operator] + else: + ops = operator + + start = time() + tstart = start + dataset = self + for op in ops: + dataset = op(dataset, + exporter=exporter, + checkpointer=checkpointer, + tracer=tracer) + end = time() + logger.info( + f'OP [{op._name}] Done in {"%.3f" % (end - start)}(s). 
' + f'Left {len(dataset)} samples.') + start = end + tend = time() + logger.info(f'All OPs are done in {"%.3f" % (tend - tstart)}(s).') + return dataset + def map(self, *args, **kargs): """Override the map func, which is called by most common operations, such that the processed samples can be accessed by nested manner.""" @@ -158,16 +208,16 @@ def map(self, *args, **kargs): kargs['function']) called_func = kargs['function'] - # For wrapped function, try to get its original unwrapped method - while hasattr(called_func, '__wrapped__'): + # For wrapped function, try to get its unwrapped (bound) method + while not inspect.ismethod(called_func) and hasattr( + called_func, '__wrapped__'): called_func = called_func.__wrapped__ - # Does the called function belong to a batched OP? - if inspect.ismethod(called_func) \ - and 'is_batched_op' in dir(called_func.__self__) \ - and callable(getattr(called_func.__self__, 'is_batched_op')) \ - and called_func.__self__.is_batched_op(): + + # Batched is always required for fault tolerance + if inspect.ismethod(called_func): kargs['batched'] = True - kargs['batch_size'] = 1 + kargs['batch_size'] = kargs.pop( + 'batch_size', 1) if called_func.__self__.is_batched_op() else 1 if 'new_fingerprint' not in kargs or kargs['new_fingerprint'] is None: new_fingerprint = generate_fingerprint(self, *args, **kargs) diff --git a/data_juicer/core/executor.py b/data_juicer/core/executor.py index 9ec12a35e..978038862 100644 --- a/data_juicer/core/executor.py +++ b/data_juicer/core/executor.py @@ -1,25 +1,20 @@ import os -from time import time +import traceback from loguru import logger -from data_juicer import use_cuda from data_juicer.config import init_configs from data_juicer.core.data import Dataset from data_juicer.format.load import load_formatter from data_juicer.format.mixture_formatter import MixtureFormatter -from data_juicer.ops import (OPERATORS, Deduplicator, Filter, Mapper, Selector, - load_ops) +from data_juicer.ops import OPERATORS, load_ops from data_juicer.utils import cache_utils from data_juicer.utils.ckpt_utils import CheckpointManager -from data_juicer.utils.constant import Fields -from data_juicer.utils.process_utils import calculate_np from ..ops.selector.frequency_specified_field_selector import \ FrequencySpecifiedFieldSelector from ..ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector -from .data import add_same_content_to_new_column from .exporter import Exporter from .tracer import Tracer @@ -43,6 +38,8 @@ def __init__(self, cfg=None): self.work_dir = self.cfg.work_dir self.ops = None + self.tracer = None + self.ckpt_manager = None # only enable it when using cache if self.cfg.use_cache: @@ -164,99 +161,10 @@ def run(self, load_data_np=None): # - If tracer is open, trace each op after it's processed # - If checkpoint is open, clean the cache files after each process logger.info('Processing data...') - start = time() - tstart = start - for op_cfg, op in zip(self.process_list, self.ops): - op_name, op_args = list(op_cfg.items())[0] - prev = dataset # record last dataset - with_rank = use_cuda() and op._accelerator == 'cuda' - if op.spec_numprocs != 0: - op_proc = op.spec_numprocs - logger.info(f'Op [{op_name}] running with sepcified ' - f'number of procs:{op.spec_numprocs}') - else: - op_proc = calculate_np(self.cfg.np, op, op_name) - try: - if isinstance(op, Mapper): - tmp = dataset.map(function=op.process, - num_proc=op_proc, - with_rank=with_rank, - desc=op_name + '_process') - if self.open_tracer and \ - op_name in 
self.op_list_to_trace: - if op.is_batched_op(): - self.tracer.trace_batch_mapper( - op_name, dataset, tmp, op.text_key) - else: - self.tracer.trace_mapper(op_name, dataset, tmp, - op.text_key) - elif isinstance(op, Filter): - if Fields.stats not in dataset.features: - # only add stats when calling filter op - dataset = dataset.map( - add_same_content_to_new_column, - fn_kwargs={ - 'new_column_name': Fields.stats, - 'initial_value': {} - }, - num_proc=self.cfg.np, - desc='Adding new column for stats') - if self.cfg.use_checkpoint: - prev = dataset - dataset = dataset.map(op.compute_stats, - num_proc=op_proc, - with_rank=with_rank, - desc=op_name + '_compute_stats') - if self.cfg.use_checkpoint: - prev = dataset - if op.stats_export_path is not None: - self.exporter.export_compute_stats( - dataset, op.stats_export_path) - tmp = dataset.filter(op.process, - num_proc=self.cfg.np, - desc=op_name + '_process') - if self.open_tracer and op_name in self.op_list_to_trace: - self.tracer.trace_filter(op_name, dataset, tmp) - elif isinstance(op, Selector): - tmp = op.process(dataset) - if self.open_tracer and op_name in self.op_list_to_trace: - self.tracer.trace_filter(op_name, dataset, tmp) - elif isinstance(op, Deduplicator): - dataset = dataset.map(op.compute_hash, - num_proc=op_proc, - with_rank=with_rank, - desc=op_name + '_compute_hash') - if self.cfg.use_checkpoint: - prev = dataset - tmp, dup_pairs = op.process( - dataset, self.tracer.show_num if self.open_tracer - and op_name in self.op_list_to_trace else 0) - if self.open_tracer and op_name in self.op_list_to_trace: - self.tracer.trace_deduplicator(op_name, dup_pairs) - else: - raise NotImplementedError - dataset = tmp - except: # noqa: E722 - logger.error(f'An error occurred during Op [{op_name}].') - import traceback - traceback.print_exc() - if self.cfg.use_checkpoint: - logger.info('Writing checkpoint of dataset processed by ' - 'last op...') - prev.cleanup_cache_files() - self.ckpt_manager.save_ckpt(prev) - exit(1) - - # clean up cache files and record processed ops - if self.cfg.use_checkpoint: - self.ckpt_manager.record(op_name, op_args) - - end = time() - logger.info(f'Op [{op_name}] Done in {"%.3f" % (end - start)}(s). ' - f'Left {len(dataset)} samples.') - start = end - tend = time() - logger.info(f'All Ops are done in {"%.3f" % (tend - tstart)}(s).') + dataset = dataset.process(self.ops, + exporter=self.exporter, + checkpointer=self.ckpt_manager, + tracer=self.tracer) # 4. 
data export logger.info('Exporting dataset to disk...') @@ -265,7 +173,6 @@ def run(self, load_data_np=None): except: # noqa: E722 logger.error('An error occurred during exporting the processed ' 'dataset.') - import traceback traceback.print_exc() if self.cfg.use_checkpoint: logger.info('Writing checkpoint of dataset processed by ' diff --git a/data_juicer/core/ray_data.py b/data_juicer/core/ray_data.py new file mode 100644 index 000000000..e9d90ab83 --- /dev/null +++ b/data_juicer/core/ray_data.py @@ -0,0 +1,132 @@ +import os + +import pyarrow as pa +from loguru import logger + +from data_juicer import cuda_device_count, use_cuda +from data_juicer.core.data import DJDataset +from data_juicer.ops import Filter, Mapper +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields +from data_juicer.utils.process_utils import calculate_np + +with AvailabilityChecking(['ray'], requires_type='dist'): + from ray.data import Dataset + + +def is_valid_path(item, dataset_dir): + full_path = os.path.abspath(os.path.join(dataset_dir, item)) + return os.path.exists(full_path) + + +def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): + for key in path_keys: + if key not in dict_with_paths: + continue + if isinstance(dict_with_paths[key], list): + dict_with_paths[key] = [ + os.path.abspath(os.path.join(dataset_dir, item)) + if isinstance(item, str) and is_valid_path(dataset_dir, item) + else item for item in dict_with_paths[key] + ] + elif isinstance(dict_with_paths[key], str): + dict_with_paths[key] = os.path.abspath( + os.path.join(dataset_dir, + dict_with_paths[key])) if is_valid_path( + dict_with_paths[key], + dataset_dir) else dict_with_paths[key] + return dict_with_paths + + +# TODO: check path for nestdataset +def set_dataset_to_absolute_path(dataset, dataset_path, cfg): + """ + Set all the path in input data to absolute path. + Checks dataset_dir and project_dir for valid paths. 
+ """ + if not (cfg.video_key in dataset.columns() or cfg.image_key + in dataset.columns() or cfg.audio_key in dataset.columns()): + return dataset + dataset_dir = os.path.dirname(dataset_path) + dataset = dataset.map(lambda item: convert_to_absolute_paths( + item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) + logger.info(f"transfer {dataset.count()} sample's paths") + return dataset + + +def preprocess_dataset(dataset: Dataset, dataset_path, cfg) -> Dataset: + if dataset_path: + dataset = set_dataset_to_absolute_path(dataset, dataset_path, cfg) + if Fields.stats not in dataset.columns(fetch_if_missing=False): + logger.info(f'columns {dataset.columns(fetch_if_missing=False)}') + + def process_batch_arrow(table: pa.Table) -> pa.Table: + new_column_data = [{} for _ in range(len(table))] + new_talbe = table.append_column(Fields.stats, [new_column_data]) + return new_talbe + + dataset = dataset.map_batches(process_batch_arrow, + batch_format='pyarrow') + return dataset + + +def get_num_gpus(op, op_proc): + if not use_cuda() or not op._accelerator == 'cuda': + return 0 + proc_per_gpu = op_proc / cuda_device_count() + return 1.0 / proc_per_gpu + + +class RayDataset(DJDataset): + + def __init__(self, + dataset: Dataset, + dataset_path: str = None, + cfg=None) -> None: + self.data = preprocess_dataset(dataset, dataset_path, cfg) + self.num_proc = None + if cfg: + self.num_proc = cfg.np + + def process(self, + operators, + *, + exporter=None, + checkpointer=None, + tracer=None) -> DJDataset: + if operators is None: + return self + elif not isinstance(operators, list): + operators = [operators] + for op in operators: + self._run_single_op(op) + return self + + def _run_single_op(self, op): + op_proc = calculate_np(op._name, op.mem_required, op.cpu_required, + self.num_proc, op.use_cuda()) + num_gpus = get_num_gpus(op, op_proc) + try: + if isinstance(op, Mapper): + self.data = self.data.map_batches(op.process, + batch_size=1, + batch_format='pyarrow', + num_gpus=num_gpus) + elif isinstance(op, Filter): + self.data = self.data.map_batches(op.compute_stats, + batch_size=1, + batch_format='pyarrow', + num_gpus=num_gpus) + if op.stats_export_path is not None: + self.data.write_json(op.stats_export_path, + force_ascii=False) + self.data = self.data.filter(op.process) + else: + logger.error( + 'Ray executor only support Filter and Mapper OPs for now') + raise NotImplementedError + except: # noqa: E722 + logger.error(f'An error occurred during Op [{op._name}].') + import traceback + traceback.print_exc() + exit(1) diff --git a/data_juicer/core/ray_executor.py b/data_juicer/core/ray_executor.py index d42d72f95..983203e52 100644 --- a/data_juicer/core/ray_executor.py +++ b/data_juicer/core/ray_executor.py @@ -1,71 +1,15 @@ -import os import time -from functools import partial -import pandas as pd -import pyarrow as pa from loguru import logger -from data_juicer import cuda_device_count, use_cuda from data_juicer.config import init_configs -from data_juicer.ops import Filter, Mapper, load_ops +from data_juicer.core.ray_data import RayDataset +from data_juicer.ops import load_ops from data_juicer.utils.availability_utils import AvailabilityChecking -from data_juicer.utils.constant import Fields -from data_juicer.utils.process_utils import calculate_np with AvailabilityChecking(['ray'], requires_type='dist'): import ray import ray.data as rd - from ray.data import ActorPoolStrategy - -from data_juicer.ops.base_op import OPERATORS - - -def is_valid_path(item, dataset_dir): - full_path = 
os.path.abspath(os.path.join(dataset_dir, item)) - return os.path.exists(full_path) - - -def convert_to_absolute_paths(dict_with_paths, dataset_dir, path_keys): - for key in path_keys: - if key not in dict_with_paths: - continue - if isinstance(dict_with_paths[key], list): - dict_with_paths[key] = [ - os.path.abspath(os.path.join(dataset_dir, item)) - if isinstance(item, str) and is_valid_path(dataset_dir, item) - else item for item in dict_with_paths[key] - ] - elif isinstance(dict_with_paths[key], str): - dict_with_paths[key] = os.path.abspath( - os.path.join(dataset_dir, - dict_with_paths[key])) if is_valid_path( - dict_with_paths[key], - dataset_dir) else dict_with_paths[key] - return dict_with_paths - - -def set_dataset_to_absolute_path(dataset, dataset_path, cfg): - """ - Set all the path in input data to absolute path. - Checks dataset_dir and project_dir for valid paths. - """ - if not (cfg.video_key in dataset.columns() or cfg.image_key - in dataset.columns() or cfg.audio_key in dataset.columns()): - return dataset - dataset_dir = os.path.dirname(dataset_path) - dataset = dataset.map(lambda item: convert_to_absolute_paths( - item, dataset_dir, [cfg.video_key, cfg.image_key, cfg.audio_key])) - logger.info(f"transfer {dataset.count()} sample's paths") - return dataset - - -def ray_batch_mapper_wrapper(samples, fn): - samples = samples.to_pandas() - res = fn(samples) - if not isinstance(res, pd.DataFrame): - res = pd.DataFrame(res) - return pa.Table.from_pandas(res) class RayExecutor: @@ -96,12 +40,6 @@ def __init__(self, cfg=None): ray.init(self.cfg.ray_address) self.process_list = self.cfg.process - def get_num_gpus(self, op, op_proc): - if not use_cuda() or not op._accelerator == 'cuda': - return 0 - proc_per_gpu = op_proc / cuda_device_count() - return 1.0 / proc_per_gpu - def run(self, load_data_np=None): """ Running the dataset process pipeline. @@ -114,98 +52,20 @@ def run(self, load_data_np=None): dataset = rd.read_json(self.cfg.dataset_path) # convert all the path in dataset to absolute path - dataset = set_dataset_to_absolute_path(dataset, self.cfg.dataset_path, - self.cfg) - logger.info('Dataset columns:', dataset.columns()) + dataset = RayDataset(dataset, self.cfg.dataset_path, self.cfg) # 2. extract processes logger.info('Preparing process operators...') self.process_list, self.ops = load_ops(self.cfg.process, self.cfg.op_fusion) # 3. 
data process - # - If tracer is open, trace each op after it's processed - # - If checkpoint is open, clean the cache files after each process - if Fields.stats not in dataset.columns(fetch_if_missing=False): - logger.info(f'columns {dataset.columns(fetch_if_missing=False)}') - - def process_batch_arrow(table: pa.Table) -> pa.Table: - new_column_data = [{} for _ in range(len(table))] - new_talbe = table.append_column(Fields.stats, - [new_column_data]) - return new_talbe - - dataset = dataset.map_batches(process_batch_arrow, - batch_format='pyarrow') - logger.info('Processing data...') tstart = time.time() - for op_cfg, op in zip(self.process_list, self.ops): - op_name, op_args = list(op_cfg.items())[0] - op_cls = OPERATORS.modules[op_name] - op_proc = calculate_np(self.cfg.np, op, op_name) - num_gpus = self.get_num_gpus(op, op_proc) - use_actor = op.use_actor() or num_gpus - try: - if isinstance(op, Mapper): - if op.is_batched_op(): - if use_actor: - dataset = dataset.map_batches( - op_cls, - compute=ActorPoolStrategy(), - concurrency=op_proc, - fn_constructor_kwargs=op_args, - batch_format='pyarrow', - num_gpus=num_gpus, - batch_size=1) - # The batch size here is same as in data.py - else: - dataset = dataset.map_batches( - partial(ray_batch_mapper_wrapper, - fn=op.process), - batch_format='pyarrow', - num_gpus=num_gpus, - batch_size=1) - # The batch size here is same as in data.py - else: - if use_actor: - dataset = dataset.map( - op_cls, - compute=ActorPoolStrategy(), - concurrency=op_proc, - fn_constructor_kwargs=op_args, - num_gpus=num_gpus) - else: - dataset = dataset.map(op.process, - num_gpus=num_gpus) - - elif isinstance(op, Filter): - if use_actor: - dataset = dataset.map(op_cls, - compute=ActorPoolStrategy(), - concurrency=op_proc, - fn_constructor_kwargs=op_args, - num_gpus=num_gpus) - else: - dataset = dataset.map(op.compute_stats, - num_gpus=num_gpus) - if op.stats_export_path is not None: - dataset.write_json(op.stats_export_path, - force_ascii=False) - dataset = dataset.filter(op.process) - else: - logger.error( - 'Ray executor only support Filter and Mapper OPs for ' - 'now') - raise NotImplementedError - except: # noqa: E722 - logger.error(f'An error occurred during Op [{op_name}].') - import traceback - traceback.print_exc() - exit(1) + dataset.process(self.ops) # 4. 
data export logger.info('Exporting dataset to disk...') - dataset.write_json(self.cfg.export_path, force_ascii=False) + dataset.data.write_json(self.cfg.export_path, force_ascii=False) tend = time.time() logger.info(f'All Ops are done in {"%.3f" % (tend - tstart)}(s).') return dataset diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py index b5b2e79d9..29e295e0f 100644 --- a/data_juicer/ops/base_op.py +++ b/data_juicer/ops/base_op.py @@ -1,16 +1,116 @@ import copy +import traceback +from functools import wraps -import pandas as pd import pyarrow as pa +from loguru import logger +from data_juicer import use_cuda +from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import size_to_bytes +from data_juicer.utils.process_utils import calculate_np from data_juicer.utils.registry import Registry OPERATORS = Registry('Operators') +def convert_list_dict_to_dict_list(samples): + # reconstruct samples from "list of dicts" to "dict of lists" + keys = samples[0].keys() + res_samples = {} + for key in keys: + res_samples[key] = [s[key] for s in samples] + return res_samples + + +def convert_dict_list_to_list_dict(samples): + # reconstruct samples from "dict of lists" to "list of dicts" + reconstructed_samples = [] + keys = list(samples.keys()) + # take any key, since they should be of same length + for i in range(len(samples[keys[0]])): + reconstructed_samples.append({key: samples[key][i] for key in samples}) + return reconstructed_samples + + +def convert_arrow_to_python(method): + + @wraps(method) + def wrapper(sample, *args, **kwargs): + if isinstance(sample, pa.Table): + sample = sample.to_pydict() + return method(sample, *args, **kwargs) + + return wrapper + + +def catch_map_batches_exception(method): + """ + For batched-map sample-level fault tolerance. + """ + + @wraps(method) + @convert_arrow_to_python + def wrapper(samples, *args, **kwargs): + try: + return method(samples, *args, **kwargs) + except Exception as e: + from loguru import logger + logger.error( + f'An error occurred in mapper operation when processing ' + f'samples {samples}, {type(e)}: {e}') + ret = {key: [] for key in samples.keys()} + ret[Fields.stats] = [] + ret[Fields.source_file] = [] + return ret + + return wrapper + + +def catch_map_single_exception(method): + """ + For single-map sample-level fault tolerance. + The input sample is expected batch_size = 1. + """ + + def is_batched(sample): + val_iter = iter(sample.values()) + first_val = next(val_iter) + if not isinstance(first_val, list): + return False + first_len = len(first_val) + return all( + isinstance(val, list) and len(val) == first_len + for val in val_iter) + + @wraps(method) + @convert_arrow_to_python + def wrapper(sample, *args, **kwargs): + if is_batched(sample): + try: + sample = convert_dict_list_to_list_dict(sample)[0] + res_sample = method(sample, *args, **kwargs) + return convert_list_dict_to_dict_list([res_sample]) + except Exception as e: + from loguru import logger + logger.error( + f'An error occurred in mapper operation when processing ' + f'sample {sample}, {type(e)}: {e}') + ret = {key: [] for key in sample.keys()} + ret[Fields.stats] = [] + ret[Fields.source_file] = [] + return ret + else: + # without fault tolerance + return method(sample, *args, **kwargs) + + return wrapper + + class OP: + _batched_op = False + def __init__(self, *args, **kwargs): """ Base class of operators. 
@@ -34,23 +134,55 @@ def __init__(self, *args, **kwargs): self._accelerator = kwargs.get('accelerator', 'cpu') # parameters to determind the number of procs for this op - self.spec_numprocs = kwargs.get('spec_numprocs', 0) + self.num_proc = kwargs.get('num_proc', 0) self.cpu_required = kwargs.get('cpu_required', 1) self.mem_required = kwargs.get('mem_required', 0) if isinstance(self.mem_required, str): self.mem_required = size_to_bytes(self.mem_required) / 1024**3 - # whether to use actor mode in ray - self._use_actor = kwargs.get('use_actor', False) - + # nested wrappers from data_juicer.core.data import wrap_func_with_nested_access - self.process = wrap_func_with_nested_access(self.process) + for name in ['process', 'compute_stats', 'compute_hash']: + method = getattr(self, name, None) + if method and callable(method): + setattr(self, f'_{name}', method) + method = wrap_func_with_nested_access(method) + setattr(self, name, method) + + def __call__(self, + dataset, + *, + exporter=None, + checkpointer=None, + tracer=None): + try: + dataset = self.run(dataset, exporter=exporter, tracer=tracer) + if checkpointer: + checkpointer.record(self._name, self._process_kwargs) + return dataset + except: # noqa: E722 + logger.error(f'An error occurred during Op [{self._name}].') + traceback.print_exc() + if checkpointer: + logger.info('Writing checkpoint of dataset processed by ' + 'last op...') + dataset.cleanup_cache_files() + checkpointer.save_ckpt(dataset) + exit(1) + + @classmethod + def is_batched_op(cls): + return cls._batched_op def process(self, *args, **kwargs): raise NotImplementedError - def use_actor(self): - return self._use_actor + def use_cuda(self): + return self._accelerator == 'cuda' and use_cuda() + + def runtime_np(self): + return calculate_np(self._name, self.mem_required, self.cpu_required, + self.num_proc, self.use_cuda()) def remove_extra_parameters(self, param_dict, keys=None): """ @@ -79,14 +211,6 @@ def add_parameters(self, init_parameter_dict, **extra_param_dict): return related_parameters -def ray_batch_mapper_wrapper(samples, fn): - samples = samples.to_pandas() - res = fn(samples) - if not isinstance(res, pd.DataFrame): - res = pd.DataFrame(res) - return pa.Table.from_pandas(res) - - class Mapper(OP): def __init__(self, *args, **kwargs): @@ -104,8 +228,11 @@ def __init__(self, *args, **kwargs): """ super(Mapper, self).__init__(*args, **kwargs) - # In default, it's a normal OP instead of batched OP - self._batched_op = kwargs.get('batched_op', False) + # runtime wrappers + if self.is_batched_op(): + self.process = catch_map_batches_exception(self.process) + else: + self.process = catch_map_single_exception(self.process) def process(self, sample): """ @@ -116,22 +243,17 @@ def process(self, sample): """ raise NotImplementedError - def is_batched_op(self): - return self._batched_op - - def __call__(self, sample): - """ - Make the class callable to enable ray actor usage - """ - if self.is_batched_op(): - # same logic as ray_batch_mapper_wrapper - samples = sample.to_pandas() - res = self.process(samples) - if not isinstance(res, pd.DataFrame): - res = pd.DataFrame(res) - return pa.Table.from_pandas(res) - else: - return self.process(sample) + def run(self, dataset, *, exporter=None, tracer=None): + new_dataset = dataset.map( + self.process, + num_proc=self.runtime_np(), + with_rank=self.use_cuda(), + desc=self._name + '_process', + ) + if tracer: + tracer.trace_mapper(self._name, dataset, new_dataset, + self.text_key) + return new_dataset class Filter(OP): @@ -150,11 
+272,15 @@ def __init__(self, *args, **kwargs): to be processed """ super(Filter, self).__init__(*args, **kwargs) - - from data_juicer.core.data import wrap_func_with_nested_access - self.compute_stats = wrap_func_with_nested_access(self.compute_stats) self.stats_export_path = kwargs.get('stats_export_path', None) + # runtime wrappers + if self.is_batched_op(): + self.compute_stats = catch_map_batches_exception( + self.compute_stats) + else: + self.compute_stats = catch_map_single_exception(self.compute_stats) + def compute_stats(self, sample, context=False): """ Compute stats for the sample which is used as a metric to decide @@ -176,11 +302,28 @@ def process(self, sample): """ raise NotImplementedError - def __call__(self, sample): - """ - Make the class callable to enable ray actor usage - """ - return self.compute_stats(sample) + def run(self, dataset, *, exporter=None, tracer=None): + if Fields.stats not in dataset.features: + from data_juicer.core.data import add_same_content_to_new_column + dataset = dataset.map(add_same_content_to_new_column, + fn_kwargs={ + 'new_column_name': Fields.stats, + 'initial_value': {} + }, + num_proc=self.runtime_np(), + desc='Adding new column for stats') + dataset = dataset.map(self.compute_stats, + num_proc=self.runtime_np(), + with_rank=self.use_cuda(), + desc=self._name + '_compute_stats') + if self.stats_export_path is not None: + exporter.export_compute_stats(dataset, self.stats_export_path) + new_dataset = dataset.filter(self.process, + num_proc=self.runtime_np(), + desc=self._name + '_process') + if tracer: + tracer.trace_filter(self._name, dataset, new_dataset) + return new_dataset class Deduplicator(OP): @@ -200,8 +343,11 @@ def __init__(self, *args, **kwargs): """ super(Deduplicator, self).__init__(*args, **kwargs) - from data_juicer.core.data import wrap_func_with_nested_access - self.compute_hash = wrap_func_with_nested_access(self.compute_hash) + # runtime wrappers + if self.is_batched_op(): + self.compute_hash = catch_map_batches_exception(self.compute_hash) + else: + self.compute_hash = catch_map_single_exception(self.compute_hash) def compute_hash(self, sample): """ @@ -223,6 +369,17 @@ def process(self, dataset, show_num=0): """ raise NotImplementedError + def run(self, dataset, *, exporter=None, tracer=None): + dataset = dataset.map(self.compute_hash, + num_proc=self.runtime_np(), + with_rank=self.use_cuda(), + desc=self._name + '_compute_hash') + show_num = tracer.show_num if tracer else 0 + new_dataset, dup_pairs = self.process(dataset, show_num) + if tracer: + tracer.trace_deduplicator(self._name, dup_pairs) + return new_dataset + class Selector(OP): @@ -249,3 +406,9 @@ def process(self, dataset): :return: selected dataset. 
""" raise NotImplementedError + + def run(self, dataset, *, exporter=None, tracer=None): + new_dataset = self.process(dataset) + if tracer: + tracer.trace_filter(self._name, dataset, new_dataset) + return new_dataset diff --git a/data_juicer/ops/load.py b/data_juicer/ops/load.py index 4e9aa248a..60aac3ec4 100644 --- a/data_juicer/ops/load.py +++ b/data_juicer/ops/load.py @@ -32,4 +32,7 @@ def load_ops(process_list, op_fusion=False): if op_fusion: new_process_list, ops = fuse_operators(new_process_list, ops) + for process, op in zip(new_process_list, ops): + op._process_kwargs = process + return new_process_list, ops diff --git a/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py index 730098e39..0c5341662 100644 --- a/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py +++ b/data_juicer/ops/mapper/audio_ffmpeg_wrapped_mapper.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.logger_utils import HiddenPrints @@ -50,8 +51,12 @@ def __init__( def process(self, sample): # there is no audio in this sample if self.audio_key not in sample or not sample[self.audio_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.audio_key] + if self.filter_name is None: return sample @@ -71,5 +76,11 @@ def process(self, sample): overwrite_output=self.overwrite_output) processed[audio_key] = output_key + # when the file is modified, its source file needs to be updated. + for i, value in enumerate(loaded_audio_keys): + if sample[Fields.source_file][i] != value: + if processed[value] != value: + sample[Fields.source_file][i] = value + sample[self.audio_key] = [processed[key] for key in loaded_audio_keys] return sample diff --git a/data_juicer/ops/mapper/image_blur_mapper.py b/data_juicer/ops/mapper/image_blur_mapper.py index 71bb03f16..536f952a9 100644 --- a/data_juicer/ops/mapper/image_blur_mapper.py +++ b/data_juicer/ops/mapper/image_blur_mapper.py @@ -56,28 +56,41 @@ def __init__(self, def process(self, sample, context=False): # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.image_key] + # load images loaded_image_keys = sample[self.image_key] sample, images = load_data_with_context(sample, context, loaded_image_keys, load_image) + processed = {} + for image_key in loaded_image_keys: + if image_key in processed: + continue - for index, value in enumerate(loaded_image_keys): if self.p < np.random.rand(): - continue + processed[image_key] = image_key else: - blured_image_key = transfer_filename(value, OP_NAME, + blured_image_key = transfer_filename(image_key, OP_NAME, **self._init_parameters) if not os.path.exists( blured_image_key) or blured_image_key not in images: - blured_image = images[value].convert('RGB').filter( + blured_image = images[image_key].convert('RGB').filter( self.blur) images[blured_image_key] = blured_image blured_image.save(blured_image_key) if context: sample[Fields.context][blured_image_key] = blured_image - loaded_image_keys[index] = blured_image_key + processed[image_key] = blured_image_key + + # when 
the file is modified, its source file needs to be updated. + for i, value in enumerate(loaded_image_keys): + if sample[Fields.source_file][i] != value: + if processed[value] != value: + sample[Fields.source_file][i] = value - sample[self.image_key] = loaded_image_keys + sample[self.image_key] = [processed[key] for key in loaded_image_keys] return sample diff --git a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py index 8b58f1e3a..76cfbfae0 100644 --- a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py @@ -98,6 +98,8 @@ class ImageCaptioningFromGPT4VMapper(Mapper): """Mapper to generate samples whose texts are generated based on gpt-4-visison and the image.""" + _batched_op = True + def __init__(self, mode: str = 'description', api_key: str = '', @@ -143,7 +145,7 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True + if mode not in ['resoning', 'description', 'conversation', 'custom']: raise ValueError( f'Mode [{mode}] is not supported. ' diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py index 5a678ad07..9b985f129 100644 --- a/data_juicer/ops/mapper/image_captioning_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_mapper.py @@ -34,6 +34,8 @@ class ImageCaptioningMapper(Mapper): """Mapper to generate samples whose captions are generated based on another model and the figure.""" + _batched_op = True + def __init__(self, hf_img2seq='Salesforce/blip2-opt-2.7b', caption_num: PositiveInt = 1, @@ -84,7 +86,7 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True + if keep_candidate_mode not in [ 'random_any', 'similar_one_simhash', 'all' ]: @@ -102,7 +104,6 @@ def __init__(self, self.prompt = prompt self.prompt_key = prompt_key self.extra_args = kwargs - if keep_candidate_mode in ['random_any', 'similar_one_simhash']: self.num_newly_generated_samples = 1 elif keep_candidate_mode in ['all']: diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index f7fdab0c3..635712a07 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -33,6 +33,8 @@ class ImageDiffusionMapper(Mapper): Generate image by diffusion model """ + _batched_op = True + def __init__(self, hf_diffusion: str = 'CompVis/stable-diffusion-v1-4', torch_dtype: str = 'fp32', @@ -97,7 +99,6 @@ def __init__(self, """ super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - self._batched_op = True self._accelerator = 'cuda' self.strength = strength self.guidance_scale = guidance_scale @@ -111,7 +112,6 @@ def __init__(self, hf_img2seq=hf_img2seq, keep_original_sample=False, prompt=self.prompt) - self.model_key = prepare_model( model_type='diffusion', pretrained_model_name_or_path=hf_diffusion, diff --git a/data_juicer/ops/mapper/image_face_blur_mapper.py b/data_juicer/ops/mapper/image_face_blur_mapper.py index 5e048b1c3..3c2c889f1 100644 --- a/data_juicer/ops/mapper/image_face_blur_mapper.py +++ b/data_juicer/ops/mapper/image_face_blur_mapper.py @@ -69,8 +69,12 @@ def __init__(self, def process(self, sample, context=False): # there is no image in this sample if self.image_key not in sample or not sample[self.image_key]: + 
sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.image_key] + # load images loaded_image_keys = sample[self.image_key] sample, images = load_data_with_context(sample, context, @@ -108,6 +112,12 @@ def process(self, sample, context=False): else: key_mapping[key] = key + # when the file is modified, its source file needs to be updated. + for i, value in enumerate(loaded_image_keys): + if sample[Fields.source_file][i] != value: + if key_mapping[value] != value: + sample[Fields.source_file][i] = value + sample[self.image_key] = [ key_mapping[key] for key in loaded_image_keys ] diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index a721cf2b3..581296b6a 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -19,6 +19,8 @@ class NlpaugEnMapper(Mapper): """Mapper to simply augment samples in English based on nlpaug library.""" + _batched_op = True + def __init__(self, sequential: bool = False, aug_num: int = 1, @@ -84,7 +86,6 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True # this is a batched OP self.aug_num = aug_num if aug_num >= 10: diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index 262d90782..4c7bdefe3 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -17,6 +17,8 @@ class NlpcdaZhMapper(Mapper): """Mapper to simply augment samples in Chinese based on nlpcda library.""" + _batched_op = True + def __init__(self, sequential: bool = False, aug_num: int = 1, @@ -68,7 +70,6 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True # this is a batched OP self.aug_num = aug_num if aug_num >= 10: diff --git a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py index 38523b4b5..97d32b121 100644 --- a/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_audio_mapper.py @@ -29,6 +29,8 @@ class VideoCaptioningFromAudioMapper(Mapper): Qwen-Audio model. """ + _batched_op = True + def __init__(self, keep_original_sample: bool = True, *args, **kwargs): """ Initialization method. @@ -41,7 +43,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs): :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True + self.keep_original_sample = keep_original_sample self.extra_args = kwargs diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py index 7ef01a098..2e3471506 100644 --- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py @@ -39,6 +39,8 @@ class VideoCaptioningFromFramesMapper(Mapper): an image-to-text model and sampled video frames. 
Captions from different frames will be concatenated to a single string.""" + _batched_op = True + def __init__( self, hf_img2seq='Salesforce/blip2-opt-2.7b', @@ -111,7 +113,6 @@ def __init__( """ super().__init__(*args, **kwargs) - self._batched_op = True self._accelerator = 'cuda' if keep_candidate_mode not in [ diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py index d03dc6482..7cf1f1b8c 100644 --- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py @@ -49,6 +49,8 @@ class VideoCaptioningFromSummarizerMapper(Mapper): texts (captions from video/audio/frames, tags from audio/frames, ...) """ + _batched_op = True + def __init__(self, hf_summarizer: str = None, consider_video_caption_from_video: bool = True, @@ -107,7 +109,7 @@ def __init__(self, :param kwargs: extra args """ super().__init__(*args, **kwargs) - self._batched_op = True + self.keep_original_sample = keep_original_sample self.extra_args = kwargs self._accelerator = 'cuda' diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py index 9dc5e34e6..3a20c20c2 100644 --- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py @@ -38,6 +38,8 @@ class VideoCaptioningFromVideoMapper(Mapper): """Mapper to generate samples whose captions are generated based on a video-to-text model and sampled video frame.""" + _batched_op = True + def __init__( self, hf_video_blip='kpyu/video-blip-opt-2.7b-ego4d', @@ -111,7 +113,6 @@ def __init__( """ super().__init__(*args, **kwargs) - self._batched_op = True self._accelerator = 'cuda' if keep_candidate_mode not in [ diff --git a/data_juicer/ops/mapper/video_face_blur_mapper.py b/data_juicer/ops/mapper/video_face_blur_mapper.py index c623d1c13..05de74cd6 100644 --- a/data_juicer/ops/mapper/video_face_blur_mapper.py +++ b/data_juicer/ops/mapper/video_face_blur_mapper.py @@ -1,6 +1,7 @@ import av from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.mm_utils import (load_data_with_context, load_video, pil_to_opencv, process_each_frame) @@ -68,8 +69,12 @@ def __init__(self, def process(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + loaded_video_keys = sample[self.video_key] sample, videos = load_data_with_context(sample, context, loaded_video_keys, load_video) @@ -90,6 +95,12 @@ def process(self, sample, context=False): if not context: video.close() + # when the file is modified, its source file needs to be updated. 
+ for i, value in enumerate(loaded_video_keys): + if sample[Fields.source_file][i] != value: + if processed_video_keys[value] != value: + sample[Fields.source_file][i] = value + sample[self.video_key] = [ processed_video_keys[key] for key in loaded_video_keys ] diff --git a/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py index 39a04da0a..6a5a38bcc 100644 --- a/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py +++ b/data_juicer/ops/mapper/video_ffmpeg_wrapped_mapper.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.logger_utils import HiddenPrints @@ -50,8 +51,12 @@ def __init__( def process(self, sample): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + if self.filter_name is None: return sample @@ -71,5 +76,11 @@ def process(self, sample): overwrite_output=self.overwrite_output) processed[video_key] = output_key + # when the file is modified, its source file needs to be updated. + for i, value in enumerate(loaded_video_keys): + if sample[Fields.source_file][i] != value: + if processed[value] != value: + sample[Fields.source_file][i] = value + sample[self.video_key] = [processed[key] for key in loaded_video_keys] return sample diff --git a/data_juicer/ops/mapper/video_remove_watermark_mapper.py b/data_juicer/ops/mapper/video_remove_watermark_mapper.py index 53e755936..316c47223 100644 --- a/data_juicer/ops/mapper/video_remove_watermark_mapper.py +++ b/data_juicer/ops/mapper/video_remove_watermark_mapper.py @@ -5,6 +5,7 @@ from jsonargparse.typing import List, PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.logger_utils import HiddenPrints from data_juicer.utils.mm_utils import (extract_video_frames_uniformly, @@ -202,8 +203,12 @@ def _clean_watermark(self, frame, watermark_mask): def process(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + loaded_video_keys = sample[self.video_key] sample, videos = load_data_with_context(sample, context, loaded_video_keys, load_video) @@ -230,5 +235,11 @@ def process_frame_func(frame): for vid_key in videos: videos[vid_key].close() + # when the file is modified, its source file needs to be updated. 
+ for i, value in enumerate(sample[self.video_key]): + if sample[Fields.source_file][i] != value: + if loaded_video_keys[i] != value: + sample[Fields.source_file][i] = value + sample[self.video_key] = loaded_video_keys return sample diff --git a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py index fa1de22d6..20b969438 100644 --- a/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py +++ b/data_juicer/ops/mapper/video_resize_aspect_ratio_mapper.py @@ -3,6 +3,7 @@ from fractions import Fraction from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.logger_utils import HiddenPrints from data_juicer.utils.mm_utils import load_video @@ -102,8 +103,12 @@ def __init__( def process(self, sample): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + loaded_video_keys = sample[self.video_key] for index, video_key in enumerate(loaded_video_keys): @@ -139,5 +144,11 @@ def process(self, sample): stream.run() loaded_video_keys[index] = resized_video_key + # when the file is modified, its source file needs to be updated. + for i, value in enumerate(sample[self.video_key]): + if sample[Fields.source_file][i] != value: + if loaded_video_keys[i] != value: + sample[Fields.source_file][i] = value + sample[self.video_key] = loaded_video_keys return sample diff --git a/data_juicer/ops/mapper/video_resize_resolution_mapper.py b/data_juicer/ops/mapper/video_resize_resolution_mapper.py index 5d026f8ae..03f72f914 100644 --- a/data_juicer/ops/mapper/video_resize_resolution_mapper.py +++ b/data_juicer/ops/mapper/video_resize_resolution_mapper.py @@ -5,6 +5,7 @@ from jsonargparse.typing import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import transfer_filename from data_juicer.utils.logger_utils import HiddenPrints from data_juicer.utils.mm_utils import load_video @@ -86,8 +87,12 @@ def __init__(self, def process(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + loaded_video_keys = sample[self.video_key] for index, video_key in enumerate(loaded_video_keys): @@ -163,5 +168,11 @@ def process(self, sample, context=False): loaded_video_keys[index] = resized_video_key + # when the file is modified, its source file needs to be updated. 
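The media mappers touched above all add the same source-file bookkeeping: initialise `Fields.source_file` from the media keys the first time the field is seen, and, once the op has produced new files, keep each modified entry pointing at the file it was loaded from. The following standalone sketch paraphrases that pattern; the function name, the `processed` mapping and the plain `'source_file'` key are illustrative stand-ins for the patch's inlined code and `Fields.source_file`.

```python
def track_source_files(sample, media_key, processed, source_key='source_file'):
    """Paraphrase of the bookkeeping the mappers above wrap around `processed`,
    a dict that maps each loaded file path to the path produced by the op."""
    loaded_keys = sample[media_key]
    # first touch: the loaded files are themselves the source files
    if source_key not in sample or not sample[source_key]:
        sample[source_key] = list(loaded_keys)
    # if this op rewrote a file, record the loaded path it was derived from
    for i, value in enumerate(loaded_keys):
        if sample[source_key][i] != value and processed[value] != value:
            sample[source_key][i] = value
    # expose the processed paths as the sample's media list
    sample[media_key] = [processed[key] for key in loaded_keys]
    return sample


sample = {'videos': ['a.mp4', 'b.mp4']}
processed = {'a.mp4': 'a_blurred.mp4', 'b.mp4': 'b.mp4'}
print(track_source_files(sample, 'videos', processed))
# {'videos': ['a_blurred.mp4', 'b.mp4'], 'source_file': ['a.mp4', 'b.mp4']}
```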
+ for i, value in enumerate(sample[self.video_key]): + if sample[Fields.source_file][i] != value: + if loaded_video_keys[i] != value: + sample[Fields.source_file][i] = value + sample[self.video_key] = loaded_video_keys return sample diff --git a/data_juicer/ops/mapper/video_split_by_duration_mapper.py b/data_juicer/ops/mapper/video_split_by_duration_mapper.py index dbd30b5a4..d7626a54e 100644 --- a/data_juicer/ops/mapper/video_split_by_duration_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_duration_mapper.py @@ -3,6 +3,7 @@ import numpy as np +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import (add_suffix_to_filename, transfer_filename) from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, @@ -29,6 +30,8 @@ class VideoSplitByDurationMapper(Mapper): """Mapper to split video by duration. """ + _batched_op = True + def __init__(self, split_duration: float = 10, min_last_split_duration: float = 0, @@ -51,7 +54,7 @@ def __init__(self, """ super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - self._batched_op = True + self.split_duration = split_duration self.min_last_split_duration = min_last_split_duration self.keep_original_sample = keep_original_sample @@ -85,11 +88,16 @@ def _process_single_sample(self, sample): # there is no video in this sample if self.video_key not in sample or sample[ self.video_key] is None or len(sample[self.video_key]) == 0: + sample[Fields.source_file] = [] return [] + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + # the split results split_sample = copy.deepcopy(sample) split_sample[self.text_key] = '' + split_sample[Fields.source_file] = [] # load all video(s) loaded_video_keys = sample[self.video_key] @@ -119,6 +127,8 @@ def _process_single_sample(self, sample): split_video_keys.extend(new_video_keys) place_holders.append(SpecialTokens.video * len(new_video_keys)) + split_sample[Fields.source_file].extend( + [video_key] * len(new_video_keys)) # insert the generated text according to given mode replacer_function = create_replacer(place_holders) @@ -152,5 +162,4 @@ def process(self, samples): res_samples = {} for key in keys: res_samples[key] = [s[key] for s in samples_after_split] - return res_samples diff --git a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py index 7f6eb3dca..4a8d276aa 100644 --- a/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_key_frame_mapper.py @@ -1,6 +1,7 @@ import copy import re +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import (add_suffix_to_filename, transfer_filename) from data_juicer.utils.mm_utils import (SpecialTokens, cut_video_by_seconds, @@ -27,6 +28,8 @@ class VideoSplitByKeyFrameMapper(Mapper): """Mapper to split video by key frame. """ + _batched_op = True + def __init__(self, keep_original_sample: bool = True, *args, **kwargs): """ Initialization method. 
@@ -40,7 +43,7 @@ def __init__(self, keep_original_sample: bool = True, *args, **kwargs): """ super().__init__(*args, **kwargs) self._init_parameters = self.remove_extra_parameters(locals()) - self._batched_op = True + self.keep_original_sample = keep_original_sample self.extra_args = kwargs @@ -68,11 +71,16 @@ def _process_single_sample(self, sample): # there is no video in this sample if self.video_key not in sample or sample[ self.video_key] is None or len(sample[self.video_key]) == 0: + sample[Fields.source_file] = [] return [] + if Fields.source_file not in sample or not sample[Fields.source_file]: + sample[Fields.source_file] = sample[self.video_key] + # the split results split_sample = copy.deepcopy(sample) split_sample[self.text_key] = '' + split_sample[Fields.source_file] = [] # load all video(s) loaded_video_keys = sample[self.video_key] @@ -101,6 +109,8 @@ def _process_single_sample(self, sample): split_video_keys.extend(new_video_keys) place_holders.append(SpecialTokens.video * len(new_video_keys)) + split_sample[Fields.source_file].extend( + [video_key] * len(new_video_keys)) # insert the generated text according to given mode replacer_function = create_replacer(place_holders) diff --git a/data_juicer/ops/mapper/video_split_by_scene_mapper.py b/data_juicer/ops/mapper/video_split_by_scene_mapper.py index 14ce456b6..18a642c12 100644 --- a/data_juicer/ops/mapper/video_split_by_scene_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_scene_mapper.py @@ -5,6 +5,7 @@ from jsonargparse.typing import NonNegativeFloat, NonNegativeInt from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields from data_juicer.utils.file_utils import (add_suffix_to_filename, transfer_filename) from data_juicer.utils.mm_utils import SpecialTokens @@ -84,6 +85,7 @@ def __init__(self, def process(self, sample, context=False): # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: + sample[Fields.source_file] = [] return sample # load videos @@ -137,6 +139,12 @@ def process(self, sample, context=False): sample[self.text_key]) sample[self.text_key] = updated_text + # when the file is modified, its source file needs to be updated. + sample[Fields.source_file] = [] + for value in loaded_video_keys: + sample[Fields.source_file].extend([value] * + len(output_video_keys[value])) + sample[self.video_key] = list( chain.from_iterable( [output_video_keys[key] for key in loaded_video_keys])) diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 479b5a689..0683c3307 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -19,6 +19,9 @@ class Fields(object): video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' + # the name of the original file from which this sample was derived. 
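The split-style mappers (by duration, key frame, and scene) rebuild the source list instead of patching it in place: every clip produced from a video contributes one entry that points back to that video, so `Fields.source_file` stays aligned with the new media list. The constant holding this field name is introduced just below. A minimal illustration of the resulting alignment, with placeholder names rather than the patch's code:

```python
def source_files_for_splits(loaded_video_keys, split_result):
    """Return one source entry per produced clip, in clip order.
    `split_result` maps each original video to the clips cut from it."""
    source_files = []
    for video_key in loaded_video_keys:
        new_video_keys = split_result[video_key]
        source_files.extend([video_key] * len(new_video_keys))
    return source_files


splits = {'a.mp4': ['a_0.mp4', 'a_1.mp4', 'a_2.mp4'], 'b.mp4': ['b_0.mp4']}
print(source_files_for_splits(['a.mp4', 'b.mp4'], splits))
# ['a.mp4', 'a.mp4', 'a.mp4', 'b.mp4']
```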
+ source_file = DEFAULT_PREFIX + 'source_file__' + # the name of diretory to store the produced multimodal data multimodal_data_output_dir = DEFAULT_PREFIX + 'produced_data__' diff --git a/data_juicer/utils/process_utils.py b/data_juicer/utils/process_utils.py index 2aa60e19f..8ef012d53 100644 --- a/data_juicer/utils/process_utils.py +++ b/data_juicer/utils/process_utils.py @@ -4,7 +4,7 @@ import psutil from loguru import logger -from data_juicer import cuda_device_count, use_cuda +from data_juicer import cuda_device_count def get_min_cuda_memory(): @@ -22,31 +22,34 @@ def get_min_cuda_memory(): return min_cuda_memory -def calculate_np(num_proc, op, op_name): +def calculate_np(name, + mem_required, + cpu_required, + num_proc=None, + use_cuda=False): """Calculate the optimum number of processes for the given OP""" if num_proc is None: num_proc = psutil.cpu_count() - if use_cuda() and op._accelerator == 'cuda': + if use_cuda: cuda_mem_available = get_min_cuda_memory() / 1024 op_proc = min( num_proc, - math.floor(cuda_mem_available / (op.mem_required + 0.1)) * + math.floor(cuda_mem_available / (mem_required + 0.1)) * cuda_device_count()) - if use_cuda() and op.mem_required == 0: - logger.warning(f'The required cuda memory of Op[{op_name}] ' + if use_cuda and mem_required == 0: + logger.warning(f'The required cuda memory of Op[{name}] ' f'has not been specified. ' f'Please specify the mem_required field in the ' f'config file, or you might encounter CUDA ' f'out of memory error. You can reference ' f'the mem_required field in the ' - f'config_all.yaml file. ') + f'config_all.yaml file.') if op_proc < 1.0: - logger.warning( - f'The required cuda memory:{op.mem_required}GB might ' - f'be more than the available cuda memory:' - f'{cuda_mem_available}GB.' - f'This Op [{op_name}] might ' - f'require more resource to run.') + logger.warning(f'The required cuda memory:{mem_required}GB might ' + f'be more than the available cuda memory:' + f'{cuda_mem_available}GB.' + f'This Op[{name}] might ' + f'require more resource to run.') op_proc = max(op_proc, 1) return op_proc else: @@ -54,15 +57,15 @@ def calculate_np(num_proc, op, op_name): cpu_available = psutil.cpu_count() mem_available = psutil.virtual_memory().available mem_available = mem_available / 1024**3 - op_proc = min(op_proc, math.floor(cpu_available / op.cpu_required)) + op_proc = min(op_proc, math.floor(cpu_available / cpu_required)) op_proc = min(op_proc, - math.floor(mem_available / (op.mem_required + 0.1))) + math.floor(mem_available / (mem_required + 0.1))) if op_proc < 1.0: - logger.warning(f'The required CPU number:{op.cpu_required} ' - f'and memory:{op.mem_required}GB might ' + logger.warning(f'The required CPU number:{cpu_required} ' + f'and memory:{mem_required}GB might ' f'be more than the available CPU:{cpu_available} ' f'and memory :{mem_available}GB.' 
- f'This Op [{op_name}] might ' + f'This Op [{name}] might ' f'require more resource to run.') op_proc = max(op_proc, 1) return op_proc diff --git a/data_juicer/utils/registry.py b/data_juicer/utils/registry.py index 8847ae2d4..38700e5b0 100644 --- a/data_juicer/utils/registry.py +++ b/data_juicer/utils/registry.py @@ -84,6 +84,7 @@ def _register_module(self, module_name=None, module_cls=None, force=False): f'{module_name} is already registered in {self._name}') self._modules[module_name] = module_cls + module_cls._name = module_name def register_module(self, module_name: str = None, diff --git a/data_juicer/utils/unittest_utils.py b/data_juicer/utils/unittest_utils.py index b9d18dbf1..724e4628d 100644 --- a/data_juicer/utils/unittest_utils.py +++ b/data_juicer/utils/unittest_utils.py @@ -2,11 +2,28 @@ import shutil import unittest +import numpy +import ray.data as rd + +from data_juicer.core.data import DJDataset, NestedDataset +from data_juicer.core.ray_data import RayDataset from data_juicer.utils.registry import Registry SKIPPED_TESTS = Registry('SkippedTests') +def TEST_TAG(*tags): + """Tags for test case. + Currently, `standalone`, `ray` are supported. + """ + + def decorator(func): + setattr(func, '__test_tags__', tags) + return func + + return decorator + + class DataJuicerTestCaseBase(unittest.TestCase): @classmethod @@ -32,3 +49,50 @@ def tearDownClass(cls, hf_model_name=None) -> None: if os.path.exists(transformers.TRANSFORMERS_CACHE): print('CLEAN all TRANSFORMERS_CACHE') shutil.rmtree(transformers.TRANSFORMERS_CACHE) + + def generate_dataset(self, data) -> DJDataset: + """Generate dataset for a specific executor. + + Args: + type (str, optional): "standalone" or "ray". + Defaults to "standalone". + """ + current_tag = getattr(self, 'current_tag', 'standalone') + if current_tag.startswith('standalone'): + return NestedDataset.from_list(data) + elif current_tag.startswith('ray'): + dataset = rd.from_items(data) + return RayDataset(dataset) + else: + raise ValueError('Unsupported type') + + def run_single_op(self, dataset: DJDataset, op, column_names): + """Run operator in the specific executor.""" + current_tag = getattr(self, 'current_tag', 'standalone') + dataset = dataset.process(op) + if current_tag.startswith('standalone'): + dataset = dataset.select_columns(column_names=column_names) + return dataset.to_list() + elif current_tag.startswith('ray'): + dataset = dataset.data.to_pandas().get(column_names) + if dataset is None: + return [] + return dataset.to_dict(orient='records') + else: + raise ValueError('Unsupported type') + + def assertDatasetEqual(self, first, second): + + def convert_record(rec): + for key in rec.keys(): + # Convert incomparable `list` to comparable `tuple` + if isinstance(rec[key], numpy.ndarray) or isinstance( + rec[key], list): + rec[key] = tuple(rec[key]) + return rec + + first = [convert_record(d) for d in first] + second = [convert_record(d) for d in second] + first = sorted(first, key=lambda x: tuple(sorted(x.items()))) + second = sorted(second, key=lambda x: tuple(sorted(x.items()))) + return self.assertEqual(first, second) diff --git a/docs/DeveloperGuide.md b/docs/DeveloperGuide.md index e887b0358..7940ed0c0 100644 --- a/docs/DeveloperGuide.md +++ b/docs/DeveloperGuide.md @@ -126,19 +126,20 @@ class StatsKeys(object): # ... (same as above) ``` - - If the operator processes data in batches rather than a single sample, it is necessary to declare `self._batched_op = True`. 
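With the refactor above, `calculate_np` no longer receives an OP object; callers pass the op name and its resource requirements explicitly, plus a `use_cuda` flag. A hedged usage sketch follows; the op names and resource figures are illustrative, and the CUDA call assumes at least one visible GPU.

```python
from data_juicer.utils.process_utils import calculate_np

# CPU-bound op: the process count is capped by available CPUs and memory.
n_cpu = calculate_np('text_length_filter', mem_required=1, cpu_required=1,
                     num_proc=8)

# CUDA op: capped instead by free GPU memory per worker and the GPU count
# (num_proc defaults to the CPU count when left unset).
n_gpu = calculate_np('image_nsfw_filter', mem_required=8, cpu_required=1,
                     use_cuda=True)

print(n_cpu, n_gpu)
```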
+ - If the operator processes data in batches rather than a single sample, it is necessary to declare `_batched_op = True`. ```python # ... (import some other libraries) OP_NAME = 'image_diffusion_mapper' @OPERATORS.register_module(OP_NAME) @LOADED_IMAGES.register_module(OP_NAME) class ImageDiffusionMapper(Mapper): + _batched_op = True + def __init__(self, # ... (OP parameters) *args, **kwargs): super().__init__(*args, **kwargs) - self._batched_op = True def process(self, samples): # ... (some codes) diff --git a/docs/DeveloperGuide_ZH.md b/docs/DeveloperGuide_ZH.md index 9046258b1..9ec85a5ce 100644 --- a/docs/DeveloperGuide_ZH.md +++ b/docs/DeveloperGuide_ZH.md @@ -121,19 +121,20 @@ class StatsKeys(object): # ... (same as above) ``` - - 如果算子批量处理数据,输入不是一个样本而是一个batch,需要声明`self._batched_op = True`。 + - 如果算子批量处理数据,输入不是一个样本而是一个batch,需要声明`_batched_op = True`。 ```python # ... (import some other libraries) OP_NAME = 'image_diffusion_mapper' @OPERATORS.register_module(OP_NAME) @LOADED_IMAGES.register_module(OP_NAME) class ImageDiffusionMapper(Mapper): + _batched_op = True + def __init__(self, # ... (OP parameters) *args, **kwargs): super().__init__(*args, **kwargs) - self._batched_op = True def process(self, samples): # ... (some codes) diff --git a/environments/dist_requires.txt b/environments/dist_requires.txt index 559ed3ec9..4060a654f 100644 --- a/environments/dist_requires.txt +++ b/environments/dist_requires.txt @@ -1,2 +1,2 @@ -ray==2.10.0 +ray==2.31.0 redis>=5.0.0 diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index b8b46867b..180c6973d 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -46,11 +46,9 @@ def test_yaml_cfg_file(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, - 'batched_op': False, } }, 'nested dict load fail, for nonparametric op') self.assertDictEqual( @@ -63,11 +61,10 @@ def test_yaml_cfg_file(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }, 'nested dict load fail, un-expected internal value') @@ -130,11 +127,10 @@ def test_mixture_cfg(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }) self.assertDictEqual( @@ -147,11 +143,10 @@ def test_mixture_cfg(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }) self.assertDictEqual( @@ -164,11 +159,10 @@ def test_mixture_cfg(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }) self.assertDictEqual( @@ -181,11 +175,10 @@ def test_mixture_cfg(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }) self.assertDictEqual( @@ -198,11 +191,10 @@ def test_mixture_cfg(self): 'audio_key': 'audios', 'video_key': 'videos', 'accelerator': 'cpu', - 'spec_numprocs': 0, + 'num_proc': 0, 
'stats_export_path': None, 'cpu_required': 1, 'mem_required': 0, - 'use_actor': False, } }) @@ -213,7 +205,7 @@ def test_op_params_parsing(self): base_class_params = { 'text_key', 'image_key', 'audio_key', 'video_key', 'accelerator', - 'spec_numprocs', 'cpu_required', 'mem_required', 'use_actor', + 'num_proc', 'cpu_required', 'mem_required', } parser = ArgumentParser(default_env=True, default_config_files=None) diff --git a/tests/format/test_unify_format.py b/tests/format/test_unify_format.py index 52b87493d..b83cd775f 100644 --- a/tests/format/test_unify_format.py +++ b/tests/format/test_unify_format.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.format.formatter import load_dataset, unify_format from data_juicer.utils.constant import Fields diff --git a/tests/ops/deduplicator/test_document_deduplicator.py b/tests/ops/deduplicator/test_document_deduplicator.py index 5a37a2e91..c24b11bfc 100644 --- a/tests/ops/deduplicator/test_document_deduplicator.py +++ b/tests/ops/deduplicator/test_document_deduplicator.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.deduplicator.document_deduplicator import \ DocumentDeduplicator diff --git a/tests/ops/deduplicator/test_document_minhash_deduplicator.py b/tests/ops/deduplicator/test_document_minhash_deduplicator.py index 5190ed1e4..9d427ac7f 100644 --- a/tests/ops/deduplicator/test_document_minhash_deduplicator.py +++ b/tests/ops/deduplicator/test_document_minhash_deduplicator.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.deduplicator.document_minhash_deduplicator import \ DocumentMinhashDeduplicator diff --git a/tests/ops/deduplicator/test_document_simhash_deduplicator.py b/tests/ops/deduplicator/test_document_simhash_deduplicator.py index ddde50e82..9b90f275d 100644 --- a/tests/ops/deduplicator/test_document_simhash_deduplicator.py +++ b/tests/ops/deduplicator/test_document_simhash_deduplicator.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.deduplicator.document_simhash_deduplicator import \ DocumentSimhashDeduplicator diff --git a/tests/ops/deduplicator/test_image_deduplicator.py b/tests/ops/deduplicator/test_image_deduplicator.py index 53c85758d..31a048d65 100644 --- a/tests/ops/deduplicator/test_image_deduplicator.py +++ b/tests/ops/deduplicator/test_image_deduplicator.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.deduplicator.image_deduplicator import ImageDeduplicator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase diff --git a/tests/ops/deduplicator/test_video_deduplicator.py b/tests/ops/deduplicator/test_video_deduplicator.py index 9541e0464..97ec5d933 100644 --- a/tests/ops/deduplicator/test_video_deduplicator.py +++ b/tests/ops/deduplicator/test_video_deduplicator.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.deduplicator.video_deduplicator import VideoDeduplicator from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase diff --git a/tests/ops/filter/test_alphanumeric_filter.py 
b/tests/ops/filter/test_alphanumeric_filter.py index efca696c2..d4ea828c0 100644 --- a/tests/ops/filter/test_alphanumeric_filter.py +++ b/tests/ops/filter/test_alphanumeric_filter.py @@ -1,27 +1,15 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.alphanumeric_filter import AlphanumericFilter from data_juicer.utils.constant import Fields -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG class AlphanumericFilterTest(DataJuicerTestCaseBase): - def _run_alphanumeric_filter(self, dataset: Dataset, target_list, op): - if Fields.stats not in dataset.features: - # TODO: - # this is a temp solution, - # only add stats when calling filter op - dataset = dataset.add_column(name=Fields.stats, - column=[{}] * dataset.num_rows) - dataset = dataset.map(op.compute_stats) - dataset = dataset.filter(op.process) - dataset = dataset.select_columns(column_names=['text']) - res_list = dataset.to_list() - self.assertEqual(res_list, target_list) - + @TEST_TAG("standalone", "ray") def test_case(self): ds_list = [{ @@ -50,10 +38,12 @@ def test_case(self): }, { 'text': 'emoji表情测试下😊,😸31231\n' }] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AlphanumericFilter(min_ratio=0.2, max_ratio=0.9) - self._run_alphanumeric_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, ["text"]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_token_case(self): ds_list = [{ @@ -76,9 +66,10 @@ def test_token_case(self): }, { 'text': 'Do you need a cup of coffee?' }] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AlphanumericFilter(tokenization=True, min_ratio=1.5) - self._run_alphanumeric_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, ["text"]) + self.assertDatasetEqual(result, tgt_list) if __name__ == '__main__': diff --git a/tests/ops/filter/test_audio_duration_filter.py b/tests/ops/filter/test_audio_duration_filter.py index 91a39bfd8..d336e9b10 100644 --- a/tests/ops/filter/test_audio_duration_filter.py +++ b/tests/ops/filter/test_audio_duration_filter.py @@ -1,11 +1,11 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.audio_duration_filter import AudioDurationFilter from data_juicer.utils.constant import Fields -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG class AudioDurationFilterTest(DataJuicerTestCaseBase): @@ -30,6 +30,7 @@ def _run_audio_duration_filter(self, res_list = dataset.to_list() self.assertEqual(res_list, target_list) + @TEST_TAG("standalone", "ray") def test_default_filter(self): ds_list = [{ @@ -46,10 +47,13 @@ def test_default_filter(self): }, { 'audios': [self.aud3_path] }] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter() - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + + @TEST_TAG("standalone", "ray") def test_filter_long_audios(self): ds_list = [{ @@ -60,10 +64,12 @@ def test_filter_long_audios(self): 'audios': [self.aud3_path] }] tgt_list = [{'audios': [self.aud1_path]}] - dataset = 
Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(max_duration=10) - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_filter_short_audios(self): ds_list = [{ @@ -74,10 +80,12 @@ def test_filter_short_audios(self): 'audios': [self.aud3_path] }] tgt_list = [{'audios': [self.aud3_path]}] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(min_duration=60) - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_filter_audios_within_range(self): ds_list = [{ @@ -88,12 +96,13 @@ def test_filter_audios_within_range(self): 'audios': [self.aud3_path] }] tgt_list = [{'audios': [self.aud2_path]}] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(min_duration=10, max_duration=20) - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_any(self): - ds_list = [{ 'audios': [self.aud1_path, self.aud2_path] }, { @@ -106,12 +115,14 @@ def test_any(self): }, { 'audios': [self.aud2_path, self.aud3_path] }] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(min_duration=10, max_duration=20, any_or_all='any') - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_all(self): ds_list = [{ @@ -122,12 +133,14 @@ def test_all(self): 'audios': [self.aud1_path, self.aud3_path] }] tgt_list = [] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(min_duration=10, max_duration=20, any_or_all='all') - self._run_audio_duration_filter(dataset, tgt_list, op) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) + @TEST_TAG("standalone", "ray") def test_filter_in_parallel(self): ds_list = [{ @@ -138,9 +151,10 @@ def test_filter_in_parallel(self): 'audios': [self.aud3_path] }] tgt_list = [{'audios': [self.aud2_path]}] - dataset = Dataset.from_list(ds_list) + dataset = self.generate_dataset(ds_list) op = AudioDurationFilter(min_duration=10, max_duration=20) - self._run_audio_duration_filter(dataset, tgt_list, op, np=2) + result = self.run_single_op(dataset, op, [op.audio_key]) + self.assertDatasetEqual(result, tgt_list) if __name__ == '__main__': diff --git a/tests/ops/filter/test_audio_nmf_snr_filter.py b/tests/ops/filter/test_audio_nmf_snr_filter.py index 728c43f39..1cc010b2f 100644 --- a/tests/ops/filter/test_audio_nmf_snr_filter.py +++ b/tests/ops/filter/test_audio_nmf_snr_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.audio_nmf_snr_filter import AudioNMFSNRFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_audio_size_filter.py b/tests/ops/filter/test_audio_size_filter.py index 00b4158d7..74b26a5df 100644 --- a/tests/ops/filter/test_audio_size_filter.py +++ 
b/tests/ops/filter/test_audio_size_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.audio_size_filter import AudioSizeFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_average_line_length_filter.py b/tests/ops/filter/test_average_line_length_filter.py index a1c39e702..e294cb77e 100644 --- a/tests/ops/filter/test_average_line_length_filter.py +++ b/tests/ops/filter/test_average_line_length_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.average_line_length_filter import \ AverageLineLengthFilter diff --git a/tests/ops/filter/test_character_repetition_filter.py b/tests/ops/filter/test_character_repetition_filter.py index 85133c133..77c1ac1d2 100644 --- a/tests/ops/filter/test_character_repetition_filter.py +++ b/tests/ops/filter/test_character_repetition_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.character_repetition_filter import \ CharacterRepetitionFilter diff --git a/tests/ops/filter/test_flagged_words_filter.py b/tests/ops/filter/test_flagged_words_filter.py index e346eb0f5..b07ef685f 100644 --- a/tests/ops/filter/test_flagged_words_filter.py +++ b/tests/ops/filter/test_flagged_words_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.flagged_words_filter import FlaggedWordFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_image_aesthetics_filter.py b/tests/ops/filter/test_image_aesthetics_filter.py index 9b5277cde..e20f9d2c6 100644 --- a/tests/ops/filter/test_image_aesthetics_filter.py +++ b/tests/ops/filter/test_image_aesthetics_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_aesthetics_filter import \ ImageAestheticsFilter diff --git a/tests/ops/filter/test_image_aspect_ratio_filter.py b/tests/ops/filter/test_image_aspect_ratio_filter.py index d8d3df0ea..cbb5998c8 100644 --- a/tests/ops/filter/test_image_aspect_ratio_filter.py +++ b/tests/ops/filter/test_image_aspect_ratio_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_aspect_ratio_filter import \ ImageAspectRatioFilter diff --git a/tests/ops/filter/test_image_face_ratio_filter.py b/tests/ops/filter/test_image_face_ratio_filter.py index 2a2327b8f..e602a775f 100644 --- a/tests/ops/filter/test_image_face_ratio_filter.py +++ b/tests/ops/filter/test_image_face_ratio_filter.py @@ -1,13 +1,14 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_face_ratio_filter import ImageFaceRatioFilter from data_juicer.utils.constant import Fields -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +@SKIPPED_TESTS.register_module() class ImageFaceRatioFilterTest(DataJuicerTestCaseBase): maxDiff = None diff --git a/tests/ops/filter/test_image_nsfw_filter.py 
b/tests/ops/filter/test_image_nsfw_filter.py index 46c68561d..0a588e272 100644 --- a/tests/ops/filter/test_image_nsfw_filter.py +++ b/tests/ops/filter/test_image_nsfw_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer import _cuda_device_count from data_juicer.ops.filter.image_nsfw_filter import ImageNSFWFilter diff --git a/tests/ops/filter/test_image_shape_filter.py b/tests/ops/filter/test_image_shape_filter.py index e7e5deaed..0a1e25e58 100644 --- a/tests/ops/filter/test_image_shape_filter.py +++ b/tests/ops/filter/test_image_shape_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_shape_filter import ImageShapeFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_image_size_filter.py b/tests/ops/filter/test_image_size_filter.py index fcc5e3e76..7d05d0828 100644 --- a/tests/ops/filter/test_image_size_filter.py +++ b/tests/ops/filter/test_image_size_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_size_filter import ImageSizeFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_image_text_matching_filter.py b/tests/ops/filter/test_image_text_matching_filter.py index 27e022181..0551da254 100644 --- a/tests/ops/filter/test_image_text_matching_filter.py +++ b/tests/ops/filter/test_image_text_matching_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_text_matching_filter import \ ImageTextMatchingFilter diff --git a/tests/ops/filter/test_image_text_similarity_filter.py b/tests/ops/filter/test_image_text_similarity_filter.py index 549ee3137..a373b6033 100644 --- a/tests/ops/filter/test_image_text_similarity_filter.py +++ b/tests/ops/filter/test_image_text_similarity_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.image_text_similarity_filter import \ ImageTextSimilarityFilter diff --git a/tests/ops/filter/test_image_watermark_filter.py b/tests/ops/filter/test_image_watermark_filter.py index def87307a..01ed2e0dc 100644 --- a/tests/ops/filter/test_image_watermark_filter.py +++ b/tests/ops/filter/test_image_watermark_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer import _cuda_device_count from data_juicer.ops.filter.image_watermark_filter import ImageWatermarkFilter diff --git a/tests/ops/filter/test_language_id_score_filter.py b/tests/ops/filter/test_language_id_score_filter.py index 21d71ceb5..8fa9fd8c6 100644 --- a/tests/ops/filter/test_language_id_score_filter.py +++ b/tests/ops/filter/test_language_id_score_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.language_id_score_filter import \ LanguageIDScoreFilter diff --git a/tests/ops/filter/test_maximum_line_length_filter.py b/tests/ops/filter/test_maximum_line_length_filter.py index ef8a6d33e..6f1cab7f6 100644 --- 
a/tests/ops/filter/test_maximum_line_length_filter.py +++ b/tests/ops/filter/test_maximum_line_length_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.maximum_line_length_filter import \ MaximumLineLengthFilter diff --git a/tests/ops/filter/test_perplexity_filter.py b/tests/ops/filter/test_perplexity_filter.py index 114bdb307..07e87d17c 100644 --- a/tests/ops/filter/test_perplexity_filter.py +++ b/tests/ops/filter/test_perplexity_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.perplexity_filter import PerplexityFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_phrase_grounding_recall_filter.py b/tests/ops/filter/test_phrase_grounding_recall_filter.py index ab0fa256c..e865c2f22 100644 --- a/tests/ops/filter/test_phrase_grounding_recall_filter.py +++ b/tests/ops/filter/test_phrase_grounding_recall_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.phrase_grounding_recall_filter import \ PhraseGroundingRecallFilter diff --git a/tests/ops/filter/test_special_characters_filter.py b/tests/ops/filter/test_special_characters_filter.py index 4ea505968..b1dd8632e 100644 --- a/tests/ops/filter/test_special_characters_filter.py +++ b/tests/ops/filter/test_special_characters_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.special_characters_filter import \ SpecialCharactersFilter diff --git a/tests/ops/filter/test_specified_field_filter.py b/tests/ops/filter/test_specified_field_filter.py index 3086e2b00..7d51e61de 100644 --- a/tests/ops/filter/test_specified_field_filter.py +++ b/tests/ops/filter/test_specified_field_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.specified_field_filter import SpecifiedFieldFilter from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase diff --git a/tests/ops/filter/test_specified_numeric_field_filter.py b/tests/ops/filter/test_specified_numeric_field_filter.py index c580f6905..008fe2d69 100644 --- a/tests/ops/filter/test_specified_numeric_field_filter.py +++ b/tests/ops/filter/test_specified_numeric_field_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.specified_numeric_field_filter import \ SpecifiedNumericFieldFilter diff --git a/tests/ops/filter/test_stop_words_filter.py b/tests/ops/filter/test_stop_words_filter.py index 8772b6960..467a04440 100644 --- a/tests/ops/filter/test_stop_words_filter.py +++ b/tests/ops/filter/test_stop_words_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.stopwords_filter import StopWordsFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_suffix_filter.py b/tests/ops/filter/test_suffix_filter.py index 48980c120..fed28594e 100644 --- a/tests/ops/filter/test_suffix_filter.py +++ b/tests/ops/filter/test_suffix_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset 
+from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.suffix_filter import SuffixFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_text_action_filter.py b/tests/ops/filter/test_text_action_filter.py index 78b40dfad..378e51eec 100644 --- a/tests/ops/filter/test_text_action_filter.py +++ b/tests/ops/filter/test_text_action_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.text_action_filter import TextActionFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_text_entity_dependency_filter.py b/tests/ops/filter/test_text_entity_dependency_filter.py index 6247318f7..29caa96c8 100644 --- a/tests/ops/filter/test_text_entity_dependency_filter.py +++ b/tests/ops/filter/test_text_entity_dependency_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.text_entity_dependency_filter import \ TextEntityDependencyFilter diff --git a/tests/ops/filter/test_text_length_filter.py b/tests/ops/filter/test_text_length_filter.py index cb5df982b..67efb6c60 100644 --- a/tests/ops/filter/test_text_length_filter.py +++ b/tests/ops/filter/test_text_length_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.text_length_filter import TextLengthFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_token_num_filter.py b/tests/ops/filter/test_token_num_filter.py index 5ee78bab2..b33aa73c1 100644 --- a/tests/ops/filter/test_token_num_filter.py +++ b/tests/ops/filter/test_token_num_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.token_num_filter import TokenNumFilter from data_juicer.utils.constant import Fields, StatsKeys diff --git a/tests/ops/filter/test_video_aesthetics_filter.py b/tests/ops/filter/test_video_aesthetics_filter.py index 48942c0d3..7f8098853 100644 --- a/tests/ops/filter/test_video_aesthetics_filter.py +++ b/tests/ops/filter/test_video_aesthetics_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_aesthetics_filter import \ VideoAestheticsFilter diff --git a/tests/ops/filter/test_video_aspect_ratio_filter.py b/tests/ops/filter/test_video_aspect_ratio_filter.py index b07844097..9b34becdc 100644 --- a/tests/ops/filter/test_video_aspect_ratio_filter.py +++ b/tests/ops/filter/test_video_aspect_ratio_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_aspect_ratio_filter import \ VideoAspectRatioFilter diff --git a/tests/ops/filter/test_video_duration_filter.py b/tests/ops/filter/test_video_duration_filter.py index 2954836bf..38f19f87f 100644 --- a/tests/ops/filter/test_video_duration_filter.py +++ b/tests/ops/filter/test_video_duration_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_duration_filter import VideoDurationFilter from 
data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_video_frames_text_similarity_filter.py b/tests/ops/filter/test_video_frames_text_similarity_filter.py index 04e7355e5..9c4978a50 100644 --- a/tests/ops/filter/test_video_frames_text_similarity_filter.py +++ b/tests/ops/filter/test_video_frames_text_similarity_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_frames_text_similarity_filter import \ VideoFramesTextSimilarityFilter diff --git a/tests/ops/filter/test_video_motion_score_filter.py b/tests/ops/filter/test_video_motion_score_filter.py index 0c7ce3f5d..d8c8367e4 100644 --- a/tests/ops/filter/test_video_motion_score_filter.py +++ b/tests/ops/filter/test_video_motion_score_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_motion_score_filter import \ VideoMotionScoreFilter diff --git a/tests/ops/filter/test_video_nsfw_filter.py b/tests/ops/filter/test_video_nsfw_filter.py index 4c70f589a..3c713407d 100644 --- a/tests/ops/filter/test_video_nsfw_filter.py +++ b/tests/ops/filter/test_video_nsfw_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer import _cuda_device_count from data_juicer.ops.filter.video_nsfw_filter import VideoNSFWFilter diff --git a/tests/ops/filter/test_video_ocr_area_ratio_filter.py b/tests/ops/filter/test_video_ocr_area_ratio_filter.py index 909086f23..9884ab1cf 100644 --- a/tests/ops/filter/test_video_ocr_area_ratio_filter.py +++ b/tests/ops/filter/test_video_ocr_area_ratio_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_ocr_area_ratio_filter import \ VideoOcrAreaRatioFilter diff --git a/tests/ops/filter/test_video_resolution_filter.py b/tests/ops/filter/test_video_resolution_filter.py index 210662a3e..f35d6cff4 100644 --- a/tests/ops/filter/test_video_resolution_filter.py +++ b/tests/ops/filter/test_video_resolution_filter.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_resolution_filter import \ VideoResolutionFilter diff --git a/tests/ops/filter/test_video_tagging_from_frames_filter.py b/tests/ops/filter/test_video_tagging_from_frames_filter.py index 242426d58..c16b07d4d 100644 --- a/tests/ops/filter/test_video_tagging_from_frames_filter.py +++ b/tests/ops/filter/test_video_tagging_from_frames_filter.py @@ -2,7 +2,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.video_tagging_from_frames_filter import \ VideoTaggingFromFramesFilter from data_juicer.utils.mm_utils import SpecialTokens @@ -21,7 +21,7 @@ def _run_video_tagging_from_frames_filter(self, source_list, target_list, num_proc=1): - dataset = NestedDataset.from_list(source_list) + dataset = Dataset.from_list(source_list) dataset = dataset.map(op.compute_stats) dataset = dataset.filter(op.process) dataset = dataset.select_columns(column_names=['text', 'videos']) diff --git a/tests/ops/filter/test_video_watermark_filter.py 
b/tests/ops/filter/test_video_watermark_filter.py index ed7e6fd94..aca75131f 100644 --- a/tests/ops/filter/test_video_watermark_filter.py +++ b/tests/ops/filter/test_video_watermark_filter.py @@ -3,7 +3,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer import _cuda_device_count from data_juicer.ops.filter.video_watermark_filter import VideoWatermarkFilter diff --git a/tests/ops/filter/test_word_num_filter.py b/tests/ops/filter/test_word_num_filter.py index 6b4522c5e..0d53a164d 100644 --- a/tests/ops/filter/test_word_num_filter.py +++ b/tests/ops/filter/test_word_num_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.words_num_filter import WordsNumFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/filter/test_word_repetition_filter.py b/tests/ops/filter/test_word_repetition_filter.py index cf5f02330..f59576ef8 100644 --- a/tests/ops/filter/test_word_repetition_filter.py +++ b/tests/ops/filter/test_word_repetition_filter.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.filter.word_repetition_filter import WordRepetitionFilter from data_juicer.utils.constant import Fields diff --git a/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py b/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py index 4ee4fdd61..bd74e608c 100644 --- a/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py +++ b/tests/ops/mapper/test_audio_ffmpeg_wrapped_mapper.py @@ -2,7 +2,7 @@ import unittest import librosa -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.audio_ffmpeg_wrapped_mapper import \ AudioFFmpegWrappedMapper diff --git a/tests/ops/mapper/test_image_blur_mapper.py b/tests/ops/mapper/test_image_blur_mapper.py index 632c1978b..98046c867 100644 --- a/tests/ops/mapper/test_image_blur_mapper.py +++ b/tests/ops/mapper/test_image_blur_mapper.py @@ -2,7 +2,7 @@ import unittest import numpy as np -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.image_blur_mapper import ImageBlurMapper from data_juicer.utils.mm_utils import load_image diff --git a/tests/ops/mapper/test_image_captioning_mapper.py b/tests/ops/mapper/test_image_captioning_mapper.py index 56d48621f..c4c3d1e3e 100644 --- a/tests/ops/mapper/test_image_captioning_mapper.py +++ b/tests/ops/mapper/test_image_captioning_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.image_captioning_mapper import \ ImageCaptioningMapper from data_juicer.utils.mm_utils import SpecialTokens @@ -27,7 +27,7 @@ def tearDownClass(cls) -> None: super().tearDownClass(cls.hf_img2seq) def _run_mapper(self, - dataset: NestedDataset, + dataset: Dataset, op, num_proc=1, caption_num=0): @@ -48,7 +48,7 @@ def test_no_eoc_special_token(self): 'images': [self.img3_path] }] caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any') @@ -69,7 +69,7 @@ def test_eoc_special_token(self): } ] caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + 
dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any') @@ -85,7 +85,7 @@ def test_multi_candidate_keep_random_any(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any') @@ -101,7 +101,7 @@ def test_multi_candidate_keep_all(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='all') @@ -118,7 +118,7 @@ def test_multi_candidate_keep_similar_one(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='similar_one_simhash') @@ -130,7 +130,7 @@ def test_multi_process(self): 'images': [self.cat_path] }] * 10 caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any') @@ -146,7 +146,7 @@ def test_no_eoc_special_token_remove_original_sample(self): 'images': [self.img3_path] }] caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any', @@ -168,7 +168,7 @@ def test_eoc_special_token_remove_original_sample(self): } ] caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any', @@ -185,7 +185,7 @@ def test_multi_candidate_keep_random_any_remove_original_sample(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any', @@ -202,7 +202,7 @@ def test_multi_candidate_keep_all_remove_original_sample(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='all', @@ -218,7 +218,7 @@ def test_multi_candidate_keep_similar_one_remove_original_sample(self): 'images': [self.img3_path] }] caption_num = 4 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='similar_one_simhash', @@ -231,7 +231,7 @@ def test_multi_process_remove_original_sample(self): 'images': [self.cat_path] }] * 10 caption_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageCaptioningMapper(hf_img2seq=self.hf_img2seq, caption_num=caption_num, keep_candidate_mode='random_any', diff --git a/tests/ops/mapper/test_image_diffusion_mapper.py b/tests/ops/mapper/test_image_diffusion_mapper.py index 77db34a99..ad241732f 100644 --- a/tests/ops/mapper/test_image_diffusion_mapper.py +++ 
b/tests/ops/mapper/test_image_diffusion_mapper.py @@ -3,7 +3,7 @@ import unittest from data_juicer import _cuda_device_count -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.image_diffusion_mapper import ImageDiffusionMapper from data_juicer.utils.mm_utils import SpecialTokens from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, @@ -33,7 +33,7 @@ def tearDownClass(cls) -> None: super().tearDownClass(cls.hf_img2seq) def _run_mapper(self, - dataset: NestedDataset, + dataset: Dataset, op, move_to_dir, num_proc=1, @@ -61,7 +61,7 @@ def test_for_strength(self): 'images': [self.cat_path] }] aug_num = 3 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, strength=1.0, aug_num=aug_num, @@ -81,7 +81,7 @@ def test_for_given_caption_list(self): }] aug_num = 2 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, aug_num=aug_num, keep_original_sample=False, @@ -103,7 +103,7 @@ def test_for_given_caption_string(self): }] aug_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, aug_num=aug_num, keep_original_sample=False, @@ -125,7 +125,7 @@ def test_for_no_given_caption(self): }] aug_num = 2 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, aug_num=aug_num, keep_original_sample=False, @@ -147,7 +147,7 @@ def test_for_fp16_given_caption_string(self): }] aug_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, torch_dtype='fp16', revision='fp16', @@ -171,7 +171,7 @@ def test_for_multi_process_given_caption_string(self): }] aug_num = 1 - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = ImageDiffusionMapper(hf_diffusion=self.hf_diffusion, aug_num=aug_num, keep_original_sample=False, diff --git a/tests/ops/mapper/test_image_face_blur_mapper.py b/tests/ops/mapper/test_image_face_blur_mapper.py index bc8aa6eaf..fcfe7e275 100644 --- a/tests/ops/mapper/test_image_face_blur_mapper.py +++ b/tests/ops/mapper/test_image_face_blur_mapper.py @@ -2,7 +2,7 @@ import shutil import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.image_face_blur_mapper import ImageFaceBlurMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase diff --git a/tests/ops/mapper/test_nlpaug_en_mapper.py b/tests/ops/mapper/test_nlpaug_en_mapper.py index 5451ffd7c..fecd5c378 100644 --- a/tests/ops/mapper/test_nlpaug_en_mapper.py +++ b/tests/ops/mapper/test_nlpaug_en_mapper.py @@ -2,7 +2,7 @@ import unittest -from data_juicer.core import NestedDataset +from data_juicer.core import NestedDataset as Dataset from data_juicer.ops.mapper.nlpaug_en_mapper import NlpaugEnMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,7 +10,7 @@ class NlpaugEnMapperTest(DataJuicerTestCaseBase): def setUp(self): - self.samples = NestedDataset.from_dict({ + self.samples = Dataset.from_dict({ 'text': [ 'I am a deep learning engineer. 
I love LLM.', 'A short test with numbers 2023' diff --git a/tests/ops/mapper/test_nlpcda_zh_mapper.py b/tests/ops/mapper/test_nlpcda_zh_mapper.py index 80aa2bf84..3624a9c35 100644 --- a/tests/ops/mapper/test_nlpcda_zh_mapper.py +++ b/tests/ops/mapper/test_nlpcda_zh_mapper.py @@ -2,7 +2,7 @@ import unittest -from data_juicer.core import NestedDataset +from data_juicer.core import NestedDataset as Dataset from data_juicer.ops.mapper.nlpcda_zh_mapper import NlpcdaZhMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase @@ -10,7 +10,7 @@ class NlpaugEnMapperTest(DataJuicerTestCaseBase): def setUp(self): - self.samples = NestedDataset.from_dict({ + self.samples = Dataset.from_dict({ 'text': ['这里一共有5种不同的数据增强方法', '这是不带数字的测试样例'], 'meta': ['meta information', 'meta information without numbers'], }) diff --git a/tests/ops/mapper/test_video_captioning_from_audio_mapper.py b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py index 3a842bab8..caadeb97b 100644 --- a/tests/ops/mapper/test_video_captioning_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_audio_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_captioning_from_audio_mapper import \ VideoCaptioningFromAudioMapper from data_juicer.utils.mm_utils import SpecialTokens @@ -35,7 +35,7 @@ def _count_generated_caption_num(text): cap_num += len(caps) return vid_num, cap_num - def _run_op(self, dataset: NestedDataset, caption_num, op, np=1): + def _run_op(self, dataset: Dataset, caption_num, op, np=1): dataset = dataset.map(op.process, num_proc=np) text_list = dataset.select_columns(column_names=['text']).to_list() for txt in text_list: @@ -55,7 +55,7 @@ def test_default_params(self): 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。', 'videos': [self.vid3_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper() self._run_op(dataset, len(dataset) * 2, op) @@ -75,7 +75,7 @@ def test_with_eoc(self): f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', 'videos': [self.vid3_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper() self._run_op(dataset, len(dataset) * 2, op) @@ -95,7 +95,7 @@ def test_no_original_samples(self): f'两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.video} {SpecialTokens.eoc}', 'videos': [self.vid3_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper(keep_original_sample=False) self._run_op(dataset, len(dataset), op) @@ -113,7 +113,7 @@ def test_multi_chunk_samples(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid3_path, self.vid1_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper() self._run_op(dataset, len(dataset) * 2, op) @@ -135,7 +135,7 @@ def test_multi_video_samples(self): 'videos': [self.vid3_path, self.vid1_path, self.vid2_path, self.vid1_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper() self._run_op(dataset, len(dataset) * 2, op) @@ -151,7 +151,7 @@ def test_parallel(self): 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。', 'videos': [self.vid3_path] }] - dataset = NestedDataset.from_list(ds_list) + 
dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromAudioMapper() self._run_op(dataset, len(dataset) * 2, op, np=2) diff --git a/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py b/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py index 955ecda4c..79f8037b9 100644 --- a/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py +++ b/tests/ops/mapper/test_video_captioning_from_summarizer_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_captioning_from_summarizer_mapper import \ VideoCaptioningFromSummarizerMapper from data_juicer.utils.mm_utils import SpecialTokens @@ -35,7 +35,7 @@ def _count_generated_caption_num(text): cap_num += len(caps) return vid_num, cap_num - def _run_op(self, dataset: NestedDataset, caption_num, op, np=1): + def _run_op(self, dataset: Dataset, caption_num, op, np=1): dataset = dataset.map(op.process, num_proc=np) text_list = dataset.select_columns(column_names=['text']).to_list() for txt in text_list: @@ -55,7 +55,7 @@ def test_default_params(self): 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。', 'videos': [self.vid3_path] }] - dataset = NestedDataset.from_list(ds_list) + dataset = Dataset.from_list(ds_list) op = VideoCaptioningFromSummarizerMapper() self._run_op(dataset, len(dataset) * 2, op) diff --git a/tests/ops/mapper/test_video_face_blur_mapper.py b/tests/ops/mapper/test_video_face_blur_mapper.py index f95531fbf..905e754fa 100644 --- a/tests/ops/mapper/test_video_face_blur_mapper.py +++ b/tests/ops/mapper/test_video_face_blur_mapper.py @@ -2,7 +2,7 @@ import shutil import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_face_blur_mapper import VideoFaceBlurMapper from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase diff --git a/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py index 1071bd864..53e39b820 100644 --- a/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py +++ b/tests/ops/mapper/test_video_ffmpeg_wrapped_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_ffmpeg_wrapped_mapper import \ VideoFFmpegWrappedMapper diff --git a/tests/ops/mapper/test_video_remove_watermark_mapper.py b/tests/ops/mapper/test_video_remove_watermark_mapper.py index fd1b32887..0cfefa76f 100644 --- a/tests/ops/mapper/test_video_remove_watermark_mapper.py +++ b/tests/ops/mapper/test_video_remove_watermark_mapper.py @@ -2,7 +2,7 @@ import shutil import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_remove_watermark_mapper import \ VideoRemoveWatermarkMapper diff --git a/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py index 3db841646..2cd09e86e 100644 --- a/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py +++ b/tests/ops/mapper/test_video_resize_aspect_ratio_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_resize_aspect_ratio_mapper import \ VideoResizeAspectRatioMapper diff --git 
a/tests/ops/mapper/test_video_resize_resolution_mapper.py b/tests/ops/mapper/test_video_resize_resolution_mapper.py index 24b22307f..9bedffb61 100644 --- a/tests/ops/mapper/test_video_resize_resolution_mapper.py +++ b/tests/ops/mapper/test_video_resize_resolution_mapper.py @@ -2,7 +2,7 @@ import unittest import ffmpeg -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_resize_resolution_mapper import \ VideoResizeResolutionMapper diff --git a/tests/ops/mapper/test_video_split_by_duration_mapper.py b/tests/ops/mapper/test_video_split_by_duration_mapper.py index 43089dfa7..c7efb9e74 100644 --- a/tests/ops/mapper/test_video_split_by_duration_mapper.py +++ b/tests/ops/mapper/test_video_split_by_duration_mapper.py @@ -3,7 +3,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_split_by_duration_mapper import \ VideoSplitByDurationMapper from data_juicer.utils.file_utils import add_suffix_to_filename @@ -28,7 +28,10 @@ def _get_res_list(self, dataset, source_list): # for keep_original_sample=True if set(output_paths) <= set(origin_paths): - res_list.append(sample) + res_list.append({ + 'text': sample['text'], + 'videos': sample['videos'] + }) continue source = source_list[idx] @@ -59,7 +62,7 @@ def _run_video_split_by_duration_mapper(self, source_list, target_list, num_proc=1): - dataset = NestedDataset.from_list(source_list) + dataset = Dataset.from_list(source_list) dataset = dataset.map(op.process, num_proc=num_proc) res_list = self._get_res_list(dataset, source_list) self.assertEqual(res_list, target_list) diff --git a/tests/ops/mapper/test_video_split_by_key_frame_mapper.py b/tests/ops/mapper/test_video_split_by_key_frame_mapper.py index 997ae9ed8..29b881298 100644 --- a/tests/ops/mapper/test_video_split_by_key_frame_mapper.py +++ b/tests/ops/mapper/test_video_split_by_key_frame_mapper.py @@ -3,12 +3,12 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_split_by_key_frame_mapper import \ VideoSplitByKeyFrameMapper from data_juicer.utils.file_utils import add_suffix_to_filename from data_juicer.utils.mm_utils import SpecialTokens -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, TEST_TAG class VideoSplitByKeyFrameMapperTest(DataJuicerTestCaseBase): @@ -20,15 +20,20 @@ class VideoSplitByKeyFrameMapperTest(DataJuicerTestCaseBase): vid3_path = os.path.join(data_path, 'video3.mp4') def _get_res_list(self, dataset, source_list): + dataset = sorted(dataset, key=lambda x: x["id"]) + source_list = sorted(source_list, key=lambda x: x["id"]) res_list = [] origin_paths = [self.vid1_path, self.vid2_path, self.vid3_path] idx = 0 - for sample in dataset.to_list(): + for sample in dataset: output_paths = sample['videos'] - # for keep_original_sample=True if set(output_paths) <= set(origin_paths): - res_list.append(sample) + res_list.append({ + 'id': sample['id'], + 'text': sample['text'], + 'videos': sample['videos'] + }) continue source = source_list[idx] @@ -48,6 +53,7 @@ def _get_res_list(self, dataset, source_list): split_frames_nums.append(cnt) res_list.append({ + 'id': sample['id'], 'text': sample['text'], 'split_frames_num': split_frames_nums }) @@ -59,33 +65,41 @@ def 
_run_video_split_by_key_frame_mapper(self, source_list, target_list, num_proc=1): - dataset = NestedDataset.from_list(source_list) - dataset = dataset.map(op.process, num_proc=num_proc) + dataset = self.generate_dataset(source_list) + # TODO: use num_proc + dataset = self.run_single_op(dataset, op, ["id", "text", "videos"]) res_list = self._get_res_list(dataset, source_list) - self.assertEqual(res_list, target_list) + self.assertDatasetEqual(res_list, target_list) + @TEST_TAG("standalone", "ray") def test(self): ds_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path] }, { + 'id': 1, 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path] }, { + 'id': 2, 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path] }] tgt_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 1, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 2, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'split_frames_num': [6] @@ -93,39 +107,49 @@ def test(self): op = VideoSplitByKeyFrameMapper(keep_original_sample=False) self._run_video_split_by_key_frame_mapper(op, ds_list, tgt_list) + @TEST_TAG("standalone", "ray") def test_keep_ori_sample(self): ds_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path] }, { + 'id': 1, 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path] }, { + 'id': 2, 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path] }] tgt_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path] }, { + 'id': 0, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 1, 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path] }, { + 'id': 1, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 2, 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path] }, { + 'id': 2, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'split_frames_num': [6] @@ -133,28 +157,35 @@ def test_keep_ori_sample(self): op = VideoSplitByKeyFrameMapper() self._run_video_split_by_key_frame_mapper(op, ds_list, tgt_list) + @TEST_TAG("standalone", "ray") def test_multi_process(self): ds_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path] }, { + 'id': 1, 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path] }, { + 'id': 2, 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path] }] tgt_list = [{ + 'id': 0, 'text': 
f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 1, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'split_frames_num': [3] }, { + 'id': 2, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'split_frames_num': [6] @@ -165,29 +196,37 @@ def test_multi_process(self): tgt_list, num_proc=2) + @TEST_TAG("standalone", "ray") def test_multi_chunk(self): + # wrong because different order ds_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。', 'videos': [self.vid1_path, self.vid2_path] }, { + 'id': 1, 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid2_path, self.vid3_path] }, { + 'id': 2, 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid1_path, self.vid3_path] }] tgt_list = [{ + 'id': 0, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'split_frames_num': [3, 3] }, { + 'id': 1, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'split_frames_num': [3, 6] }, { + 'id': 2, 'text': f'{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'split_frames_num': [3, 6] diff --git a/tests/ops/mapper/test_video_split_by_scene_mapper.py b/tests/ops/mapper/test_video_split_by_scene_mapper.py index 0637d2beb..dbbc32553 100644 --- a/tests/ops/mapper/test_video_split_by_scene_mapper.py +++ b/tests/ops/mapper/test_video_split_by_scene_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_split_by_scene_mapper import \ VideoSplitBySceneMapper diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py index 045af79b1..a81fb51c7 100644 --- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -1,7 +1,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_tagging_from_audio_mapper import \ VideoTaggingFromAudioMapper from data_juicer.utils.constant import Fields @@ -30,7 +30,7 @@ def _run_video_tagging_from_audio_mapper(self, source_list, target_list, num_proc=1): - dataset = NestedDataset.from_list(source_list) + dataset = Dataset.from_list(source_list) dataset = 
dataset.map(op.process, num_proc=num_proc) res_list = dataset.select_columns([Fields.video_audio_tags ])[Fields.video_audio_tags] diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index b33c1d867..b34c45151 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -2,7 +2,7 @@ import os import unittest -from data_juicer.core.data import NestedDataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.mapper.video_tagging_from_frames_mapper import \ VideoTaggingFromFramesMapper from data_juicer.utils.constant import Fields @@ -22,7 +22,7 @@ def _run_video_tagging_from_frames_mapper(self, source_list, target_list, num_proc=1): - dataset = NestedDataset.from_list(source_list) + dataset = Dataset.from_list(source_list) dataset = dataset.map(op.process, num_proc=num_proc) res_list = dataset.to_list() self.assertEqual(res_list, target_list) diff --git a/tests/ops/selector/test_frequency_specified_field_selector.py b/tests/ops/selector/test_frequency_specified_field_selector.py index 4593e83ef..5bf584293 100644 --- a/tests/ops/selector/test_frequency_specified_field_selector.py +++ b/tests/ops/selector/test_frequency_specified_field_selector.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.selector.frequency_specified_field_selector import \ FrequencySpecifiedFieldSelector diff --git a/tests/ops/selector/test_topk_specified_field_selector.py b/tests/ops/selector/test_topk_specified_field_selector.py index f10129ded..4a3e35e75 100644 --- a/tests/ops/selector/test_topk_specified_field_selector.py +++ b/tests/ops/selector/test_topk_specified_field_selector.py @@ -1,6 +1,6 @@ import unittest -from datasets import Dataset +from data_juicer.core.data import NestedDataset as Dataset from data_juicer.ops.selector.topk_specified_field_selector import \ TopkSpecifiedFieldSelector diff --git a/tests/run.py b/tests/run.py index 8ff91e459..81028ee01 100644 --- a/tests/run.py +++ b/tests/run.py @@ -19,7 +19,9 @@ sys.path.append(file_dir) parser = argparse.ArgumentParser('test runner') -parser.add_argument('--list_tests', action='store_true', help='list all tests') +parser.add_argument('--tag', choices=["standalone", "ray"], + default="standalone", + help="the tag of tests being run") parser.add_argument('--pattern', default='test_*.py', help='test file pattern') parser.add_argument('--test_dir', default='tests', @@ -27,45 +29,47 @@ args = parser.parse_args() -def gather_test_cases(test_dir, pattern, list_tests): - test_suite = unittest.TestSuite() - discover = unittest.defaultTestLoader.discover(test_dir, - pattern=pattern, - top_level_dir=None) +class TaggedTestLoader(unittest.TestLoader): + def __init__(self, tag="standalone"): + super().__init__() + self.tag = tag + + def loadTestsFromTestCase(self, testCaseClass): + # set tag to testcase class + setattr(testCaseClass, 'current_tag', self.tag) + test_names = self.getTestCaseNames(testCaseClass) + loaded_suite = self.suiteClass() + for test_name in test_names: + test_case = testCaseClass(test_name) + test_method = getattr(test_case, test_name) + if self.tag in getattr(test_method, '__test_tags__', ["standalone"]): + loaded_suite.addTest(test_case) + return loaded_suite + +def gather_test_cases(test_dir, pattern, tag): + test_to_run = unittest.TestSuite() + test_loader = 
TaggedTestLoader(tag) + discover = test_loader.discover(test_dir, pattern=pattern, top_level_dir=None) print(f'These tests will be skipped due to some reasons: ' f'{SKIPPED_TESTS.modules}') for suite_discovered in discover: - - for test_case in suite_discovered: - logger.info(f'Prepare for test [{test_case}]') - # filter out those tests that need to be skipped - filtered_test_suite = unittest.TestSuite() - for tc in test_case: - if type(tc) in SKIPPED_TESTS.modules.values(): + for test_suite in suite_discovered: + for test_case in test_suite: + if type(test_case) in SKIPPED_TESTS.modules.values(): continue - filtered_test_suite.addTest(tc) - if filtered_test_suite.countTestCases() == 0: - continue - - test_suite.addTest(test_case) - if hasattr(test_case, '__iter__'): - for subcase in test_case: - if list_tests: - print(subcase) - else: - if list_tests: - print(test_case) - return test_suite + logger.info(f'Add test case [{test_case._testMethodName}]' + f' from {test_case.__class__.__name__}') + test_to_run.addTest(test_case) + return test_to_run def main(): runner = unittest.TextTestRunner() test_suite = gather_test_cases(os.path.abspath(args.test_dir), - args.pattern, args.list_tests) - if not args.list_tests: - res = runner.run(test_suite) - if not res.wasSuccessful(): - exit(1) + args.pattern, args.tag) + res = runner.run(test_suite) + if not res.wasSuccessful(): + exit(1) if __name__ == '__main__':
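
The tag-based runner above pairs with a `TEST_TAG` decorator imported from `data_juicer.utils.unittest_utils` (see the test_video_split_by_key_frame_mapper.py hunk). That module is not part of this excerpt, so the sketch below is only an assumption about how such a decorator could attach tags for `TaggedTestLoader` to read through `__test_tags__`; the `DemoMapperTest` class and its method names are purely hypothetical.

import unittest


def TEST_TAG(*tags):
    """Attach execution tags to a test method (assumed behaviour)."""
    def decorator(func):
        # read back by TaggedTestLoader.loadTestsFromTestCase via __test_tags__
        func.__test_tags__ = tags
        return func
    return decorator


class DemoMapperTest(unittest.TestCase):

    @TEST_TAG('standalone', 'ray')
    def test_runs_in_both_modes(self):
        self.assertTrue(True)

    def test_standalone_only(self):
        # untagged methods fall back to the loader default: ['standalone']
        self.assertTrue(True)

With the loader wired into tests/run.py, `python tests/run.py --tag ray` collects only the ray-tagged cases, while the default `--tag standalone` keeps the previous behaviour; `--pattern` and `--test_dir` retain their old meaning.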
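
The rewritten `_get_res_list` in test_video_split_by_key_frame_mapper.py now sorts both the processed dataset and the source list by an explicit `id` field and compares plain dicts rather than raw samples, because with `num_proc > 1` or a ray backend the processed samples can come back in any order. A tiny self-contained illustration of that comparison step (the values below are made up, not taken from the real test videos):

# sort both sides on a stable key before comparing, since parallel or
# distributed execution does not guarantee output order
produced = [{'id': 2, 'split_frames_num': [6]},
            {'id': 0, 'split_frames_num': [3]},
            {'id': 1, 'split_frames_num': [3]}]
expected = [{'id': 0, 'split_frames_num': [3]},
            {'id': 1, 'split_frames_num': [3]},
            {'id': 2, 'split_frames_num': [6]}]
assert sorted(produced, key=lambda x: x['id']) == expected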
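
Most of the remaining hunks are a mechanical import sweep: `from datasets import Dataset` (or a bare `NestedDataset` import) becomes `from data_juicer.core.data import NestedDataset as Dataset`, so the test bodies keep their familiar `Dataset.from_list(...)` calls while actually exercising the refactored `NestedDataset`. A minimal sketch of the pattern, using a made-up two-row sample:

# the alias swaps the backing class without touching the rest of the test
from data_juicer.core.data import NestedDataset as Dataset

ds = Dataset.from_list([{'text': 'hello'}, {'text': 'world'}])
# the calls used throughout these tests (map, select_columns, to_list,
# len, ...) remain available on NestedDataset
assert len(ds) == 2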