diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 8d37a5ae6..470865693 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -92,7 +92,7 @@ process: keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True in default. prompt: null # a string prompt to guide the generation of blip2 model for all samples globally. It's None in default, which means no prompt provided. prompt_key: null # the key name of fields in samples to store prompts for each sample. It's used for set different prompts for different samples. If it's none, use prompt in parameter "prompt". It's None in default. - mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + mem_required: '16GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched - image_diffusion_mapper: # generate images by diffusion model hf_diffusion: 'CompVis/stable-diffusion-v1-4' # stable diffusion model name on huggingface to generate image torch_dtype: 'fp32' # the floating point type used to load the diffusion model. Can be one of ['fp32', 'fp16', 'bf16'] @@ -103,7 +103,7 @@ process: keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated images in the final datasets and the original images will be removed. It's True in default. 
caption_key: null # the key name of fields in samples to store captions for each images, the caption guide the diffusion model to produce what the image is hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate caption if caption_key is null - mem_required: '25GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched + mem_required: '8GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched - image_face_blur_mapper: # blur faces detected in images blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel @@ -335,7 +335,7 @@ process: lang: en # compute perplexity in what language max_ppl: 1500 # the max perplexity score to filter text - phrase_grounding_recall_filter: # filter samples according to the locating recall of phrases extracted from text in the images. - hf_owlvit: openai/clip-vit-base-patch32 # name of used Hugging Face Owl-ViT + hf_owlvit: google/owlvit-base-patch32 # name of used Hugging Face Owl-ViT min_recall: 0.1 # the min phrase grounding recall of filter range max_recall: 1.0 # the max phrase grounding recall of filter range horizontal_flip: false # flip image horizontally (left to right). diff --git a/data_juicer/__init__.py b/data_juicer/__init__.py index 615cf52f7..e0f765a29 100644 --- a/data_juicer/__init__.py +++ b/data_juicer/__init__.py @@ -6,9 +6,13 @@ import multiprocess as mp from loguru import logger +# allow loading truncated images for some images that are too large. 
+from PIL import ImageFile from data_juicer.utils.availability_utils import _is_package_available +ImageFile.LOAD_TRUNCATED_IMAGES = True + # For now, only INFO will be shown. Later the severity level will be changed # when setup_logger is called to initialize the logger. logger.remove() diff --git a/data_juicer/ops/deduplicator/image_deduplicator.py b/data_juicer/ops/deduplicator/image_deduplicator.py index d61e18cea..ab3d7fbc9 100644 --- a/data_juicer/ops/deduplicator/image_deduplicator.py +++ b/data_juicer/ops/deduplicator/image_deduplicator.py @@ -64,6 +64,9 @@ def __init__(self, self.text_dedup_op = DocumentDeduplicator(**kwargs) def compute_hash(self, sample, context=False): + # get hash of text first + if self.consider_text: + sample = self.text_dedup_op.compute_hash(sample) # check if it's computed already if HashKeys.imagehash in sample: return sample @@ -82,8 +85,6 @@ def compute_hash(self, sample, context=False): for key in images: sample[HashKeys.imagehash] += self.hasher.encode_image( image_array=np.array(images[key])) - if self.consider_text: - sample = self.text_dedup_op.compute_hash(sample) return sample def process(self, dataset, show_num=0): diff --git a/data_juicer/ops/deduplicator/video_deduplicator.py b/data_juicer/ops/deduplicator/video_deduplicator.py index 8073fec2e..17b84c0ba 100644 --- a/data_juicer/ops/deduplicator/video_deduplicator.py +++ b/data_juicer/ops/deduplicator/video_deduplicator.py @@ -36,6 +36,9 @@ def __init__(self, consider_text: bool = False, *args, **kwargs): self.text_dedup_op = DocumentDeduplicator(**kwargs) def compute_hash(self, sample, context=False): + # get hash of text first + if self.consider_text: + sample = self.text_dedup_op.compute_hash(sample) # check if it's computed already if HashKeys.videohash in sample: return sample @@ -59,8 +62,6 @@ def compute_hash(self, sample, context=False): md5_hash.update(bytes(packet)) sample[HashKeys.videohash] = md5_hash.hexdigest() - if self.consider_text: - sample = 
self.text_dedup_op.compute_hash(sample) return sample def process(self, dataset, show_num=0): diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index bef411e2a..e8612db2d 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -554,7 +554,7 @@ def move_to_cuda(model, rank): for module in model: if callable(getattr(module, 'to', None)): - logger.info( + logger.debug( f'Moving {module.__class__.__name__} to CUDA device {rank}') module.to(f'cuda:{rank}') diff --git a/tests/ops/deduplicator/test_image_deduplicator.py b/tests/ops/deduplicator/test_image_deduplicator.py index 53c85758d..20a27ccd6 100644 --- a/tests/ops/deduplicator/test_image_deduplicator.py +++ b/tests/ops/deduplicator/test_image_deduplicator.py @@ -32,8 +32,9 @@ class ImageDeduplicatorTest(DataJuicerTestCaseBase): os.symlink(img6_path, img7_path) def _run_image_deduplicator(self, dataset: Dataset, target_list, op): - key_list = [op.image_key, op.text_key] \ - if op.consider_text else [op.image_key] + expected_keys = [op.image_key, op.text_key] + key_list = [key for key in expected_keys + if len(target_list) > 0 and key in target_list[0]] dataset = dataset.map(op.compute_hash) dataset, _ = op.process(dataset) @@ -292,6 +293,73 @@ def test_8(self): op = ImageDeduplicator(method='ahash') self._run_image_deduplicator(dataset, tgt_list, op) + def test_no_image(self): + + ds_list = [{ + 'images': [], + 'text': 'text1', + }, { + 'images': [], + 'text': 'text2', + }, { + 'images': [self.img7_path], + 'text': ' text6', + }, { + 'images': [self.img6_path], + 'text': ' text6', + }] + tgt_list = [{ + 'images': [], + 'text': 'text1', + }, { + 'images': [], + 'text': 'text2', + }, { + 'images': [self.img7_path], + 'text': ' text6', + }] + dataset = Dataset.from_list(ds_list) + op = ImageDeduplicator() + self._run_image_deduplicator(dataset, tgt_list, op) + + def test_no_image_consider_text(self): + + ds_list = [{ + 'images': [], + 'text': 'text1', + }, { 
+ 'images': [], + 'text': 'text2', + }, { + 'images': [], + 'text': 'text1', + }, { + 'images': [], + 'text': 'text3', + }, { + 'images': [self.img7_path], + 'text': ' text6', + }, { + 'images': [self.img6_path], + 'text': ' text6', + }] + tgt_list = [{ + 'images': [], + 'text': 'text1', + }, { + 'images': [], + 'text': 'text2', + }, { + 'images': [], + 'text': 'text3', + }, { + 'images': [self.img7_path], + 'text': ' text6', + }] + dataset = Dataset.from_list(ds_list) + op = ImageDeduplicator(consider_text=True) + self._run_image_deduplicator(dataset, tgt_list, op) + if __name__ == '__main__': unittest.main() diff --git a/tests/ops/deduplicator/test_video_deduplicator.py b/tests/ops/deduplicator/test_video_deduplicator.py index 9541e0464..4e4c9bb30 100644 --- a/tests/ops/deduplicator/test_video_deduplicator.py +++ b/tests/ops/deduplicator/test_video_deduplicator.py @@ -32,8 +32,9 @@ class VideoDeduplicatorTest(DataJuicerTestCaseBase): os.symlink(video6_path, video7_path) def _run_video_deduplicator(self, dataset: Dataset, target_list, op): - key_list = [op.video_key, op.text_key] \ - if op.consider_text else [op.video_key] + expected_keys = [op.video_key, op.text_key] + key_list = [key for key in expected_keys + if len(target_list) > 0 and key in target_list[0]] dataset = dataset.map(op.compute_hash) dataset, _ = op.process(dataset) @@ -224,6 +225,94 @@ def test_5(self): op = VideoDeduplicator() self._run_video_deduplicator(dataset, tgt_list, op) + def test_no_video(self): + + ds_list = [{ + 'videos': [], + 'text': '