Fix several bugs for image OPs (#351)
* fix a bug that occurs when there is no multimodal data and text is considered

* allow the analyzer to use CUDA
* fix a bug in the format conversion tool
* change the PIL flag to allow loading truncated images

* fix the default model for phrase_grounding_recall_filter

* update GPU memory requirements for two OPs
HYLcool authored Jul 12, 2024
1 parent aa6ba47 commit 9801714
Showing 8 changed files with 176 additions and 13 deletions.
6 changes: 3 additions & 3 deletions configs/config_all.yaml
@@ -92,7 +92,7 @@ process:
keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated captions in the final datasets and the original captions will be removed. It's True by default.
prompt: null # a string prompt to guide the generation of the blip2 model for all samples globally. It's None by default, which means no prompt is provided.
prompt_key: null # the key name of the field in samples that stores a prompt for each sample. It's used to set different prompts for different samples. If it's None, use the prompt in the parameter "prompt". It's None by default.
-mem_required: '20GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
+mem_required: '16GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- image_diffusion_mapper: # generate images by diffusion model
hf_diffusion: 'CompVis/stable-diffusion-v1-4' # stable diffusion model name on huggingface to generate image
torch_dtype: 'fp32' # the floating point type used to load the diffusion model. Can be one of ['fp32', 'fp16', 'bf16']
@@ -103,7 +103,7 @@ process:
keep_original_sample: true # whether to keep the original sample. If it's set to False, there will be only generated images in the final datasets and the original images will be removed. It's True by default.
caption_key: null # the key name of the field in samples that stores a caption for each image; the captions guide the diffusion model to produce the images
hf_img2seq: 'Salesforce/blip2-opt-2.7b' # model name on huggingface to generate captions if caption_key is null
-mem_required: '25GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
+mem_required: '8GB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- image_face_blur_mapper: # blur faces detected in images
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
radius: 2 # radius of blur kernel
@@ -335,7 +335,7 @@ process:
lang: en # the language in which to compute perplexity
max_ppl: 1500 # the max perplexity score used to filter text
- phrase_grounding_recall_filter: # filter samples according to the grounding recall of phrases extracted from the text in the images.
-hf_owlvit: openai/clip-vit-base-patch32 # name of the Hugging Face Owl-ViT model to use
+hf_owlvit: google/owlvit-base-patch32 # name of the Hugging Face Owl-ViT model to use
min_recall: 0.1 # the min phrase grounding recall of filter range
max_recall: 1.0 # the max phrase grounding recall of filter range
horizontal_flip: false # flip image horizontally (left to right).
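Two notes on these config fixes. The lowered `mem_required` values retune the per-OP memory estimates that, per the inline comments, bound how many worker processes can be launched on a given machine. The `hf_owlvit` change is a real bug fix: `openai/clip-vit-base-patch32` is a plain CLIP checkpoint, not an OWL-ViT one, so the filter's default model could not be loaded through the OWL-ViT classes. Below is a minimal sketch of computing phrase grounding recall with the corrected checkpoint via the standard transformers OWL-ViT API; the image path, phrase list, and threshold are illustrative, not data-juicer internals.

# hedged sketch: ground text phrases in an image with the corrected default
import torch
from PIL import Image
from transformers import OwlViTForObjectDetection, OwlViTProcessor

processor = OwlViTProcessor.from_pretrained('google/owlvit-base-patch32')
model = OwlViTForObjectDetection.from_pretrained('google/owlvit-base-patch32')

image = Image.open('sample.jpg')    # illustrative image path
phrases = [['a dog', 'a frisbee']]  # phrases extracted from the sample text

inputs = processor(text=phrases, images=image, return_tensors='pt')
with torch.no_grad():
    outputs = model(**inputs)

# a phrase counts as grounded if any detection for it clears the threshold
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(
    outputs, threshold=0.1, target_sizes=target_sizes)[0]
grounded = {phrases[0][i] for i in results['labels'].tolist()}
recall = len(grounded) / len(phrases[0])  # compared against min/max_recall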
4 changes: 4 additions & 0 deletions data_juicer/__init__.py
@@ -6,9 +6,13 @@

import multiprocess as mp
from loguru import logger
+# allow loading truncated images for some overly large images.
+from PIL import ImageFile

from data_juicer.utils.availability_utils import _is_package_available

+ImageFile.LOAD_TRUNCATED_IMAGES = True

# For now, only INFO will be shown. Later the severity level will be changed
# when setup_logger is called to initialize the logger.
logger.remove()
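For context on the new PIL flag: by default, Pillow raises an OSError partway through decoding when an image file is cut short, which would abort a whole processing run over one bad sample. A small sketch of the behavioral difference; the truncated file name is hypothetical.

from PIL import Image, ImageFile

# default behavior: decoding a truncated file raises OSError mid-load
try:
    Image.open('truncated.jpg').load()
except OSError as err:
    print(f'without the flag: {err}')

# with the flag set, Pillow pads the missing data and yields a usable image
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.open('truncated.jpg').load()  # succeeds; missing region filled with gray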
5 changes: 3 additions & 2 deletions data_juicer/ops/deduplicator/image_deduplicator.py
@@ -64,6 +64,9 @@ def __init__(self,
self.text_dedup_op = DocumentDeduplicator(**kwargs)

    def compute_hash(self, sample, context=False):
+        # get hash of text first
+        if self.consider_text:
+            sample = self.text_dedup_op.compute_hash(sample)
        # check if it's computed already
        if HashKeys.imagehash in sample:
            return sample
@@ -82,8 +85,6 @@ def compute_hash(self, sample, context=False):
        for key in images:
            sample[HashKeys.imagehash] += self.hasher.encode_image(
                image_array=np.array(images[key]))
-        if self.consider_text:
-            sample = self.text_dedup_op.compute_hash(sample)
        return sample

    def process(self, dataset, show_num=0):
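This reordering is the fix for the first bullet of the commit message: `compute_hash` returns early when the image hash is already cached and, in the elided middle of the hunk, when the sample carries no images. With `consider_text=True`, the text hash was therefore never computed for text-only samples, and `process` later compared a key that did not exist. Hashing the text before any early return closes the gap; the same reorder is applied to the video deduplicator below.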
5 changes: 3 additions & 2 deletions data_juicer/ops/deduplicator/video_deduplicator.py
@@ -36,6 +36,9 @@ def __init__(self, consider_text: bool = False, *args, **kwargs):
self.text_dedup_op = DocumentDeduplicator(**kwargs)

    def compute_hash(self, sample, context=False):
+        # get hash of text first
+        if self.consider_text:
+            sample = self.text_dedup_op.compute_hash(sample)
        # check if it's computed already
        if HashKeys.videohash in sample:
            return sample
@@ -59,8 +62,6 @@ def compute_hash(self, sample, context=False):
                    md5_hash.update(bytes(packet))

        sample[HashKeys.videohash] = md5_hash.hexdigest()
-        if self.consider_text:
-            sample = self.text_dedup_op.compute_hash(sample)
        return sample

    def process(self, dataset, show_num=0):
2 changes: 1 addition & 1 deletion data_juicer/utils/model_utils.py
@@ -554,7 +554,7 @@ def move_to_cuda(model, rank):

    for module in model:
        if callable(getattr(module, 'to', None)):
-            logger.info(
+            logger.debug(
                f'Moving {module.__class__.__name__} to CUDA device {rank}')
            module.to(f'cuda:{rank}')

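Demoting the per-module message from INFO to DEBUG presumably pairs with the "allow the analyzer to use CUDA" bullet: with GPU-backed OPs also running under the analyzer, every worker logs one line per module it moves to CUDA, which at INFO level would flood the console.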
72 changes: 70 additions & 2 deletions tests/ops/deduplicator/test_image_deduplicator.py
@@ -32,8 +32,9 @@ class ImageDeduplicatorTest(DataJuicerTestCaseBase):
        os.symlink(img6_path, img7_path)

    def _run_image_deduplicator(self, dataset: Dataset, target_list, op):
-        key_list = [op.image_key, op.text_key] \
-            if op.consider_text else [op.image_key]
+        expected_keys = [op.image_key, op.text_key]
+        key_list = [key for key in expected_keys
+                    if len(target_list) > 0 and key in target_list[0]]

        dataset = dataset.map(op.compute_hash)
        dataset, _ = op.process(dataset)
@@ -292,6 +293,73 @@ def test_8(self):
        op = ImageDeduplicator(method='ahash')
        self._run_image_deduplicator(dataset, tgt_list, op)

    def test_no_image(self):

        ds_list = [{
            'images': [],
            'text': 'text1',
        }, {
            'images': [],
            'text': 'text2',
        }, {
            'images': [self.img7_path],
            'text': '<image> text6',
        }, {
            'images': [self.img6_path],
            'text': '<image> text6',
        }]
        tgt_list = [{
            'images': [],
            'text': 'text1',
        }, {
            'images': [],
            'text': 'text2',
        }, {
            'images': [self.img7_path],
            'text': '<image> text6',
        }]
        dataset = Dataset.from_list(ds_list)
        op = ImageDeduplicator()
        self._run_image_deduplicator(dataset, tgt_list, op)

    def test_no_image_consider_text(self):

        ds_list = [{
            'images': [],
            'text': 'text1',
        }, {
            'images': [],
            'text': 'text2',
        }, {
            'images': [],
            'text': 'text1',
        }, {
            'images': [],
            'text': 'text3',
        }, {
            'images': [self.img7_path],
            'text': '<image> text6',
        }, {
            'images': [self.img6_path],
            'text': '<image> text6',
        }]
        tgt_list = [{
            'images': [],
            'text': 'text1',
        }, {
            'images': [],
            'text': 'text2',
        }, {
            'images': [],
            'text': 'text3',
        }, {
            'images': [self.img7_path],
            'text': '<image> text6',
        }]
        dataset = Dataset.from_list(ds_list)
        op = ImageDeduplicator(consider_text=True)
        self._run_image_deduplicator(dataset, tgt_list, op)


if __name__ == '__main__':
    unittest.main()
93 changes: 91 additions & 2 deletions tests/ops/deduplicator/test_video_deduplicator.py
@@ -32,8 +32,9 @@ class VideoDeduplicatorTest(DataJuicerTestCaseBase):
        os.symlink(video6_path, video7_path)

    def _run_video_deduplicator(self, dataset: Dataset, target_list, op):
-        key_list = [op.video_key, op.text_key] \
-            if op.consider_text else [op.video_key]
+        expected_keys = [op.video_key, op.text_key]
+        key_list = [key for key in expected_keys
+                    if len(target_list) > 0 and key in target_list[0]]

        dataset = dataset.map(op.compute_hash)
        dataset, _ = op.process(dataset)
@@ -224,6 +225,94 @@ def test_5(self):
        op = VideoDeduplicator()
        self._run_video_deduplicator(dataset, tgt_list, op)

    def test_no_video(self):

        ds_list = [{
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video2_path],
            'text': '<video> text2'
        }, {
            'videos': [self.video3_path],
            'text': '<video> text3'
        }, {
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video5_path],
            'text': '<video> text5'
        }, {
            'videos': [],
            'text': '<video> text3'
        }, {
            'videos': [self.video7_path],
            'text': '<video> text7'
        }]
        tgt_list = [{
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video2_path],
            'text': '<video> text2'
        }, {
            'videos': [self.video3_path],
            'text': '<video> text3'
        }, {
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [],
            'text': '<video> text3'
        }]
        dataset = Dataset.from_list(ds_list)
        op = VideoDeduplicator()
        self._run_video_deduplicator(dataset, tgt_list, op)

    def test_no_video_consider_text(self):

        ds_list = [{
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video2_path],
            'text': '<video> text2'
        }, {
            'videos': [self.video3_path],
            'text': '<video> text3'
        }, {
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video5_path],
            'text': '<video> text5'
        }, {
            'videos': [],
            'text': '<video> text3'
        }, {
            'videos': [self.video7_path],
            'text': '<video> text3'
        }]
        tgt_list = [{
            'videos': [],
            'text': '<video> text1'
        }, {
            'videos': [self.video2_path],
            'text': '<video> text2'
        }, {
            'videos': [self.video3_path],
            'text': '<video> text3'
        }, {
            'videos': [self.video5_path],
            'text': '<video> text5'
        }, {
            'videos': [],
            'text': '<video> text3'
        }]
        dataset = Dataset.from_list(ds_list)
        op = VideoDeduplicator(consider_text=True)
        self._run_video_deduplicator(dataset, tgt_list, op)


if __name__ == '__main__':
    unittest.main()
@@ -195,7 +195,7 @@ def clean_sentence(sentence, round):
f'dataset. Please check and fix it and '
f'retry.')
# need to restore questions for samples with only captions
-ori_convs = ori_ds[id2idx[str(id)]]['conversations']
+ori_convs = ori_ds[id2idx[str(sid)]]['conversations']
conversations.append(ori_convs[0]) # add question
conversations.append({
'from': ori_convs[1]['from'],
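This last hunk is the format-conversion-tool fix named in the commit message: no local named `id` is visible in the hunk, so `str(id)` presumably stringified Python's built-in `id` function into something like '<built-in function id>', a key that can never exist in `id2idx`; `sid`, defined outside the hunk, is the actual sample id being looked up.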
