Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add text_pair_similarity_filter #405

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions configs/config_all.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,12 @@ process:
- text_length_filter: # filter text with length out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- text_pair_similarity_filter: # filter samples according to the similarity score between the text pair.
hf_clip: 'openai/clip-vit-base-patch32' # model name of the CLIP model on huggingface
min_score: 0.1 # the min similarity score of filter range
max_score: 1.0 # the max similarity score of filter range
text_key_second: None # used to store the other sentence in the text pair
any_or_all: "any" # keep this sample when any/all text pairs meet the filter condition
- token_num_filter: # filter text with total token number out of specific range
hf_tokenizer: EleutherAI/pythia-6.9b-deduped # name of used Hugging Face tokenizer
min_num: 10 # the min number of filter range
Expand Down
16 changes: 9 additions & 7 deletions data_juicer/ops/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
specified_field_filter, specified_numeric_field_filter,
stopwords_filter, suffix_filter, text_action_filter,
text_entity_dependency_filter, text_length_filter,
token_num_filter, video_aesthetics_filter,
video_aspect_ratio_filter, video_duration_filter,
video_frames_text_similarity_filter, video_motion_score_filter,
video_nsfw_filter, video_ocr_area_ratio_filter,
video_resolution_filter, video_tagging_from_frames_filter,
video_watermark_filter, word_repetition_filter,
words_num_filter)
text_pair_similarity_filter, token_num_filter,
video_aesthetics_filter, video_aspect_ratio_filter,
video_duration_filter, video_frames_text_similarity_filter,
video_motion_score_filter, video_nsfw_filter,
video_ocr_area_ratio_filter, video_resolution_filter,
video_tagging_from_frames_filter, video_watermark_filter,
word_repetition_filter, words_num_filter)
from .alphanumeric_filter import AlphanumericFilter
from .audio_duration_filter import AudioDurationFilter
from .audio_nmf_snr_filter import AudioNMFSNRFilter
Expand Down Expand Up @@ -47,6 +47,7 @@
from .text_action_filter import TextActionFilter
from .text_entity_dependency_filter import TextEntityDependencyFilter
from .text_length_filter import TextLengthFilter
from .text_pair_similarity_filter import TextPairSimilarityFilter
from .token_num_filter import TokenNumFilter
from .video_aesthetics_filter import VideoAestheticsFilter
from .video_aspect_ratio_filter import VideoAspectRatioFilter
Expand Down Expand Up @@ -104,6 +105,7 @@
'FlaggedWordFilter',
'WordRepetitionFilter',
'VideoMotionScoreFilter',
'TextPairSimilarityFilter'
]

# yapf: enable
120 changes: 120 additions & 0 deletions data_juicer/ops/filter/text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging

import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.ops.base_op import OPERATORS, Filter
from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import get_model, prepare_model

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

OP_NAME = 'text_pair_similarity_filter'

with AvailabilityChecking(['torch', 'transformers'], OP_NAME):

import torch
import transformers # noqa: F401

# avoid hanging when calling clip in multiprocessing
torch.set_num_threads(1)


@OPERATORS.register_module(OP_NAME)
class TextPairSimilarityFilter(Filter):
"""Filter to keep text pairs with similarities between texts
within a specific range."""

_accelerator = 'cuda'

def __init__(self,
hf_clip='openai/clip-vit-base-patch32',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.1,
max_score: ClosedUnitInterval = 1.0,
text_key_second=None,
any_or_all: str = 'any',
*args,
**kwargs):
"""
Initialization method.

:param hf_clip: clip model name on huggingface to compute
the similarity between image and text.
:param min_score: The min similarity to keep samples.
:param max_score: The max similarity to keep samples.
:param text_key_second: used to store the other sentence
in the text pair.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all images. 'any': keep this sample if any images meet the
condition. 'all': keep this sample only if all images meet the
condition.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.min_score = min_score
self.max_score = max_score
if any_or_all not in ['any', 'all']:
raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')
self.model_key = prepare_model(model_type='huggingface',
pretrained_model_name_or_path=hf_clip,
trust_remote_code=trust_remote_code)
self.text_key_second = text_key_second

def compute_stats(self, sample, rank=None, context=False):

# check if it's computed already
if StatsKeys.text_pair_similarity in sample[Fields.stats]:
return sample

# there is no target text
if self.text_key_second is None:
logger.error('This OP (text_pair_similarity_filter) requires \
processing multiple fields, and you need to specify \
valid `text_key_second`')

# there is no text in this sample
if (self.text_key not in sample or len(sample[self.text_key]) == 0
or self.text_key_second not in sample
or len(sample[self.text_key_second]) == 0):
sample[Fields.stats][StatsKeys.text_pair_similarity] = np.array(
[], dtype=np.float64)
return sample

model, processor = get_model(self.model_key, rank, self.use_cuda())

text1 = sample[self.text_key]
text2 = sample[self.text_key_second]

text_tensors = processor([text1, text2],
padding=True,
return_tensors='pt').to(model.device)
text_features = model.get_text_features(**text_tensors)

similarity = torch.cosine_similarity(text_features[0],
text_features[1],
dim=0)
sample[Fields.stats][StatsKeys.text_pair_similarity] = [similarity]

return sample

def process(self, sample, rank=None):
similarity = sample[Fields.stats][StatsKeys.text_pair_similarity]
if len(similarity) <= 0:
return True

keep_bools = np.array([
self.min_score <= sim_value <= self.max_score
for sim_value in similarity
])

# different strategies
if self.any:
return keep_bools.any()
else:
return keep_bools.all()
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class StatsKeysConstant(object):
special_char_ratio = 'special_char_ratio'
stopwords_ratio = 'stopwords_ratio'
text_len = 'text_len'
text_pair_similarity = 'text_pair_similarity'
num_action = 'num_action'
num_dependency_edges = 'num_dependency_edges'
num_token = 'num_token'
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples |
| [ Filter ]( #filter ) | 41 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 42 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |

Expand Down Expand Up @@ -130,6 +130,7 @@ All the specific operators are listed below, each featured with several capabili
| text_action_filter | General | en, zh | Keeps samples containing action verbs in their texts |
| text_entity_dependency_filter | General | en, zh | Keeps samples containing dependency edges for an entity in the dependency tree of the texts |
| text_length_filter | General | en, zh | Keeps samples with total text length within the specified range |
| text_pair_similarity_filter | General | en, zh | Keeps text pairs with text feature cosine similarity within the specified range based on a CLIP model |
| token_num_filter | General | en, zh | Keeps samples with token count within the specified range |
| video_aesthetics_filter | Video | - | Keeps samples whose specified frames have aesthetics scores within the specified range |
| video_aspect_ratio_filter | Video | - | Keeps samples containing videos with aspect ratios within the specified range |
Expand Down
3 changes: 2 additions & 1 deletion docs/Operators_ZH.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
|------------------------------------|:--:|---------------|
| [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 |
| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 |
| [ Filter ]( #filter ) | 41 | 过滤低质量样本 |
| [ Filter ]( #filter ) | 42 | 过滤低质量样本 |
| [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 |
| [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 |

Expand Down Expand Up @@ -128,6 +128,7 @@ Data-Juicer 中的算子分为以下 5 种类型。
| text_action_filter | General | en, zh | 保留文本部分包含动作的样本 |
| text_entity_dependency_filter | General | en, zh | 保留文本部分的依存树中具有非独立实体的样本 |
| text_length_filter | General | en, zh | 保留总文本长度在指定范围内的样本 |
| text_pair_similarity_filter | General | en, zh | 保留文本特征余弦相似度(基于CLIP模型)在指定范围内的样本 |
| token_num_filter | General | en, zh | 保留token数在指定范围内的样本 |
| video_aspect_ratio_filter | Video | - | 保留包含视频的宽高比在指定范围内的样本 |
| video_duration_filter | Video | - | 保留包含视频的时长在指定范围内的样本 |
Expand Down
66 changes: 66 additions & 0 deletions tests/ops/filter/test_text_pair_similarity_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import unittest

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.filter.text_pair_similarity_filter import TextPairSimilarityFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import SpecialTokens
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase


class TextPairSimilarityFilterTest(DataJuicerTestCaseBase):

hf_clip = "openai/clip-vit-base-patch32"

text_key = "text"
text_key_second = "target_text"


@classmethod
def tearDownClass(cls) -> None:
super().tearDownClass(cls.hf_clip)

def _run_filter(self, dataset: Dataset, op, num_proc=1):

if Fields.stats not in dataset.features:
# TODO:
# this is a temp solution,
# only add stats when calling filter op
dataset = dataset.add_column(name=Fields.stats,
column=[{}] * dataset.num_rows)

dataset = dataset.map(op.compute_stats,
num_proc=num_proc,
with_rank=True)
dataset = dataset.filter(op.process, num_proc=num_proc)
dataset = dataset.select_columns(column_names=[self.text_key,
self.text_key_second])
res_list = dataset.to_list()
print(res_list)

def test_no_eoc_special_token(self):

ds_list = [{
self.text_key_second: 'a lovely cat',
self.text_key: 'a lovely cat',
}, {
self.text_key_second: 'a lovely cat',
self.text_key: 'a cute cat',
}, {
self.text_key_second: 'a lovely cat',
self.text_key: 'a black dog',
}]


dataset = Dataset.from_list(ds_list)
op = TextPairSimilarityFilter(hf_clip=self.hf_clip,
any_or_all='any',
min_score=0.1,
max_score=0.99,
text_key_second=self.text_key_second)
self._run_filter(dataset, op)


if __name__ == '__main__':
unittest.main()
Loading