Skip to content

Commit

Permalink
add image_face_count_filter (#452)
Browse files Browse the repository at this point in the history
Co-authored-by: unknown <[email protected]>
  • Loading branch information
TobyJasper and unknown authored Oct 18, 2024
1 parent abdbcff commit 539b12d
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 0 deletions.
126 changes: 126 additions & 0 deletions data_juicer/ops/filter/image_face_count_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os

import lazy_loader as lazy
import numpy as np
from loguru import logger

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import (detect_faces, load_data_with_context,
load_image)
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import AUTOINSTALL, OPERATORS, UNFORKABLE, Filter
from ..op_fusion import LOADED_IMAGES

# Name under which this operator is registered in the registries below.
OP_NAME = 'image_face_count_filter'

# OpenCV is obtained through lazy_loader instead of a plain `import cv2`.
cv2 = lazy.load('cv2')


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageFaceCountFilter(Filter):
    """Filter to keep samples with the number of faces within a specific range.

    For every image of a sample the number of faces found by an OpenCV
    classifier is stored under ``StatsKeys.face_counts``; ``process`` then
    keeps or drops the sample depending on whether those counts fall into
    ``[min_face_count, max_face_count]``.
    """

    # Default keyword arguments forwarded to ``detect_faces``. This dict is a
    # shared, class-level template and must never be mutated: every instance
    # builds its own copy in ``__init__``.
    _default_kwargs = {
        'scaleFactor': 1.1,
        'minNeighbors': 3,
        'minSize': None,
        'maxSize': None,
    }

    def __init__(self,
                 cv_classifier: str = '',
                 min_face_count: int = 1,
                 max_face_count: int = 1,
                 any_or_all: str = 'any',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param cv_classifier: OpenCV classifier path for face detection.
            By default, we will use 'haarcascade_frontalface_alt.xml'.
        :param min_face_count: Minimum number of faces required for samples.
        :param max_face_count: Maximum number of faces allowed for samples.
        :param any_or_all: Keep this sample with 'any' or 'all' strategy of
            all images. 'any': keep this sample if any images meet the
            condition. 'all': keep this sample only if all images meet the
            condition.
        :param args: Extra positional arguments.
        :param kwargs: Extra keyword arguments. Keys that also appear in
            ``_default_kwargs`` override the corresponding detection
            default for this instance only.
        :raises ValueError: If ``any_or_all`` is neither 'any' nor 'all'.
        """
        super().__init__(*args, **kwargs)
        AUTOINSTALL.check(['opencv-python'])

        if cv_classifier == '':
            cv_classifier = os.path.join(cv2.data.haarcascades,
                                         'haarcascade_frontalface_alt.xml')

        self.min_face_count = min_face_count
        self.max_face_count = max_face_count

        # Build a per-instance dict rather than aliasing the class-level
        # template, so overrides given to one operator do not leak into
        # other instances or into the defaults themselves.
        self.extra_kwargs = {
            key: kwargs.get(key, default)
            for key, default in self._default_kwargs.items()
        }

        if any_or_all not in ['any', 'all']:
            raise ValueError(f'Keep strategy [{any_or_all}] is not supported. '
                             f'Can only be one of ["any", "all"].')
        self.any = (any_or_all == 'any')

        self.model_key = prepare_model(model_type='opencv_classifier',
                                       model_path=cv_classifier)

    def compute_stats(self, sample, context=False):
        """
        Count the faces in every image of the sample.

        :param sample: Sample to analyse; its ``Fields.stats`` entry is
            updated in place.
        :param context: Forwarded to ``load_data_with_context``.
        :return: The sample, with ``StatsKeys.face_counts`` holding one
            count per entry of ``sample[self.image_key]`` (an empty array
            when the sample has no images).
        """
        # check if it's computed already (must test this op's own key, not
        # the one written by the face-ratio filter)
        if StatsKeys.face_counts in sample[Fields.stats]:
            return sample

        # there is no image in this sample
        if self.image_key not in sample or not sample[self.image_key]:
            # integer dtype, matching the int counts stored for real images
            sample[Fields.stats][StatsKeys.face_counts] = np.array(
                [], dtype=np.int64)
            return sample

        # load images
        loaded_image_keys = sample[self.image_key]
        sample, images = load_data_with_context(sample, context,
                                                loaded_image_keys, load_image)

        model = get_model(self.model_key)

        # count the number of detected faces in each image
        face_counts = {}
        try:
            for key, image in images.items():
                dets = detect_faces(image, model, **self.extra_kwargs)
                face_counts[key] = len(dets)
        except Exception as e:
            # Log and propagate the real failure. Swallowing it here would
            # only surface later as an unrelated KeyError when the counts
            # of the unprocessed images are looked up below.
            logger.exception(e)
            raise
        logger.debug(f'face counts: {face_counts}')

        # keep the order (and any duplicates) of the sample's image list
        sample[Fields.stats][StatsKeys.face_counts] = [
            face_counts[key] for key in loaded_image_keys
        ]
        return sample

    def process(self, sample):
        """
        Decide whether the sample is kept.

        :param sample: Sample whose ``StatsKeys.face_counts`` stat has
            already been computed.
        :return: True if the sample has no face counts at all; otherwise
            whether any/all (depending on the configured strategy) of its
            counts lie within ``[min_face_count, max_face_count]``.
        """
        face_counts = sample[Fields.stats][StatsKeys.face_counts]
        # ``len`` is used on purpose: the stat may be a list or a numpy
        # array, and array truthiness is ambiguous.
        if len(face_counts) == 0:
            return True

        keep_bools = np.array([
            self.min_face_count <= face_count <= self.max_face_count
            for face_count in face_counts
        ])

        # different strategies
        if self.any:
            return keep_bools.any()
        else:
            return keep_bools.all()
1 change: 1 addition & 0 deletions data_juicer/utils/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class StatsKeysConstant(object):
image_sizes = 'image_sizes'
face_ratios = 'face_ratios'
face_detections = 'face_detections'
face_counts = 'face_counts'
image_aesthetics_scores = 'image_aesthetics_scores'
image_nsfw_score = 'image_nsfw_score'
image_watermark_prob = 'image_watermark_prob'
Expand Down
Binary file added tests/ops/data/img8.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
97 changes: 97 additions & 0 deletions tests/ops/filter/test_image_face_count_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import os
import unittest

from data_juicer.core.data import NestedDataset as Dataset

from data_juicer.ops.filter.image_face_count_filter import ImageFaceCountFilter
from data_juicer.utils.constant import Fields
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS


@SKIPPED_TESTS.register_module()
class ImageFaceCountFilterTest(DataJuicerTestCaseBase):
    """Tests for ImageFaceCountFilter using the images under ``../data``."""

    data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..',
                             'data')
    img1_path = os.path.join(data_path, 'cat.jpg')
    img2_path = os.path.join(data_path, 'lena.jpg')
    img3_path = os.path.join(data_path, 'img8.jpg')

    @staticmethod
    def _samples(*image_groups):
        """Build one ``{'images': [...]}`` sample per group of image paths."""
        return [{'images': list(group)} for group in image_groups]

    def _run_face_count_filter(self, dataset: Dataset, target_list, op, num_proc=1):
        """Compute stats, apply the filter and compare with ``target_list``."""
        stats_missing = Fields.stats not in dataset.features
        if stats_missing:
            empty_stats = [{}] * dataset.num_rows
            dataset = dataset.add_column(name=Fields.stats,
                                         column=empty_stats)
        with_stats = dataset.map(op.compute_stats, num_proc=num_proc)
        kept = with_stats.filter(op.process, num_proc=num_proc)
        kept = kept.remove_columns(Fields.stats)
        self.assertEqual(kept.to_list(), target_list)

    def test_filter_1(self):
        # exactly one face: only the second image is expected to survive
        source = self._samples([self.img1_path], [self.img2_path],
                               [self.img3_path])
        expected = self._samples([self.img2_path])
        face_filter = ImageFaceCountFilter(min_face_count=1, max_face_count=1)
        self._run_face_count_filter(Dataset.from_list(source), expected,
                                    face_filter)

    def test_filter_2(self):
        # one to five faces: second and third images are expected to survive
        source = self._samples([self.img1_path], [self.img2_path],
                               [self.img3_path])
        expected = self._samples([self.img2_path], [self.img3_path])
        face_filter = ImageFaceCountFilter(min_face_count=1, max_face_count=5)
        self._run_face_count_filter(Dataset.from_list(source), expected,
                                    face_filter)

    def test_filter_multi_proc(self):
        # same expectation as test_filter_2, run with three worker processes
        source = self._samples([self.img1_path], [self.img2_path],
                               [self.img3_path])
        expected = self._samples([self.img2_path], [self.img3_path])
        face_filter = ImageFaceCountFilter(min_face_count=1, max_face_count=5)
        self._run_face_count_filter(Dataset.from_list(source), expected,
                                    face_filter, num_proc=3)

    def test_any(self):
        # 'any' strategy: a sample is kept if at least one image qualifies
        source = self._samples(
            [self.img1_path, self.img2_path],
            [self.img2_path, self.img3_path],
            [self.img1_path, self.img3_path],
        )
        expected = self._samples(
            [self.img1_path, self.img2_path],
            [self.img2_path, self.img3_path],
        )
        face_filter = ImageFaceCountFilter(min_face_count=1,
                                           max_face_count=1,
                                           any_or_all='any')
        self._run_face_count_filter(Dataset.from_list(source), expected,
                                    face_filter)

    def test_all(self):
        # 'all' strategy: a sample is kept only if every image qualifies
        source = self._samples(
            [self.img1_path, self.img2_path],
            [self.img2_path, self.img3_path],
            [self.img1_path, self.img3_path],
        )
        expected = self._samples([self.img2_path, self.img3_path])
        face_filter = ImageFaceCountFilter(min_face_count=1,
                                           max_face_count=5,
                                           any_or_all='all')
        self._run_face_count_filter(Dataset.from_list(source), expected,
                                    face_filter)


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 comments on commit 539b12d

Please sign in to comment.