Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
API/ENH: Add image database step, postproc module
Browse files Browse the repository at this point in the history
This is a big one.

We add a processing step: load all images in parallel (we use
multiprocessing) into a in-memory db (a dict) and resize to the
dimensions used for the NN model. It turns out that loading many images
from disk (even < 10 MB/image, even with an SSD) is slow. With the
images in memory, the fingerprints loop actually loads all CPU cores
~100% with TensorFlow w/o waiting for IO. Functions such as
fingerprint(s) which used to get a file name or a list of those now work
with (the dict of) image arrays.

Rename imagecluster.py -> calc.py

Add postproc.py and move make_links() there. Add function to plot a grid
of images arranged into clusters. We build a numpy array out of images
and use only one imshow() call .. cool eh? We use the in-memory images
here as well. Else, plotting would be painfully slow.
  • Loading branch information
elcorto committed Feb 18, 2019
1 parent 1817c07 commit 872d45f
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 81 deletions.
142 changes: 76 additions & 66 deletions imagecluster/imagecluster.py → imagecluster/calc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import os

import multiprocessing as mp
import functools

import PIL.Image
from scipy.spatial import distance
from scipy.cluster import hierarchy
import numpy as np

import PIL.Image, os, shutil
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing import image
from keras.models import Model

from . import common

pj = os.path.join


Expand Down Expand Up @@ -42,39 +49,45 @@ def get_model(layer='fc2'):
return model


def fingerprint(fn, model, size):
"""Load image from file `fn`, resize to `size` and run through `model`
(keras.models.Model).
# keras.preprocessing.image.load_img() uses img.rezize(shape) with the default
# interpolation of PIL.Image.resize() which is pretty bad (see
# imagecluster/play/pil_resample_methods.py). Given that we are restricted to
# small inputs of 224x224 by the VGG network, we should do our best to keep as
# much information from the original image as possible. This is a gut feeling,
# untested. But given that model.predict() is 10x slower than PIL image loading
# and resizing .. who cares.
#
# (224, 224, 3)
##img = image.load_img(fn, target_size=size)
##... = image.img_to_array(img)
def _img_worker(fn, size):
print(fn)
return fn, image.img_to_array(PIL.Image.open(fn).resize(size, 3),
dtype=int)


def image_arrays(imagedir, size):
_f = functools.partial(_img_worker, size=size)
with mp.Pool(mp.cpu_count()) as pool:
ret = pool.map(_f, common.get_files(imagedir))
return dict(ret)


def fingerprint(img_arr, model):
"""Run image array (3d array) run through `model` (keras.models.Model).
Parameters
----------
fn : str
filename
img_arr : 3d array
(1,x,y) or (x,y,1), depending on keras.preprocessing.image.img_to_array
and "jq '.image_data_format' ~/.keras/keras.json"
(channels_{first,last}), see :func:`imagecluster.main.image_arrays`
model : keras.models.Model instance
size : tuple
input image size (width, height), must match `model`, e.g. (224,224)
Returns
-------
fingerprint : 1d array
"""
print(fn)

# keras.preprocessing.image.load_img() uses img.rezize(shape) with the
# default interpolation of PIL.Image.resize() which is pretty bad (see
# imagecluster/play/pil_resample_methods.py). Given that we are restricted
# to small inputs of 224x224 by the VGG network, we should do our best to
# keep as much information from the original image as possible. This is a
# gut feeling, untested. But given that model.predict() is 10x slower than
# PIL image loading and resizing .. who cares.
#
# (224, 224, 3)
##img = image.load_img(fn, target_size=size)
img = PIL.Image.open(fn).resize(size, 3)

# (224, 224, {3,1})
arr3d = image.img_to_array(img)

# (224, 224, 1) -> (224, 224, 3)
#
# Simple hack to convert a grayscale image to fake RGB by replication of
Expand All @@ -85,47 +98,47 @@ def fingerprint(fn, model, size):
# the image representation than color, such that this hack makes it possible
# to process gray-scale images with nets trained on color images (like
# VGG16).
if arr3d.shape[2] == 1:
arr3d = arr3d.repeat(3, axis=2)
#
# We assme channels_last here. Fix if needed.
if img_arr.shape[2] == 1:
img_arr = img_arr.repeat(3, axis=2)

# (1, 224, 224, 3)
arr4d = np.expand_dims(arr3d, axis=0)
arr4d = np.expand_dims(img_arr, axis=0)

# (1, 224, 224, 3)
arr4d_pp = preprocess_input(arr4d)
return model.predict(arr4d_pp)[0,:]


# Cannot use multiprocessing (only tensorflow backend tested):
# TypeError: can't pickle _thread.lock objects
# The error doesn't come from functools.partial since those objects are
# pickable since python3. The reason is the keras.model.Model, which is not
# pickable. However keras with tensorflow backend runs multi-threaded
# (model.predict()), so we don't need that. I guess it will scale better if we
# parallelize over images than to run a muti-threaded tensorflow on each image,
# but OK. On low core counts (2-4), it won't matter.
# Cannot use multiprocessing (only tensorflow backend tested, rumor has it that
# the TF computation graph is not built multiple times, i.e. pickling (what
# multiprocessing does with _worker) doen't play nice with Keras models which
# use Tf backend). The call to the parallel version of fingerprints() starts
# but seems to hang forever. However, Keras with Tensorflow backend runs
# multi-threaded (model.predict()), so we can sort of live with that. Even
# though Tensorflow has not the best scaling on the CPU, on low core counts
# (2-4), it won't matter that much. Also, TF was built to run on GPUs, not
# scale out multi-core CPUs.
#
##def _worker(fn, model, size):
##def _worker(img_arr, model):
## print(fn)
## return fn, fingerprint(fn, model, size)
## return fn, fingerprint(img_arr, model)
##
##import functools, multiprocessing
##def fingerprints(files, model, size=(224,224)):
## worker = functools.partial(_worker,
## model=model,
## size=size)
## pool = multiprocessing.Pool(multiprocessing.cpu_count()/2)
## return dict(pool.map(worker, files))

##
##def fingerprints(ias, model):
## _f = functools.partial(_worker, model=model)
## with mp.Pool(int(mp.cpu_count()/2)) as pool:
## ret = pool.map(_f, ias.items())
## return dict(ret)

def fingerprints(files, model, size=(224,224)):
"""Calculate fingerprints for all `files`.
def fingerprints(ias, model):
"""Calculate fingerprints for all image arrays in `ias`.
Parameters
----------
files : sequence
image filenames
model, size : see :func:`fingerprint`
ias : see :func:`image_arrays`
model : see :func:`fingerprint`
Returns
-------
Expand All @@ -135,7 +148,11 @@ def fingerprints(files, model, size=(224,224)):
...
}
"""
return dict((fn, fingerprint(fn, model, size)) for fn in files)
fps = {}
for fn,img_arr in ias.items():
print(fn)
fps[fn] = fingerprint(img_arr, model)
return fps


def cluster(fps, sim=0.5, method='average', metric='euclidean',
Expand Down Expand Up @@ -163,7 +180,7 @@ def cluster(fps, sim=0.5, method='average', metric='euclidean',
clusters [, extra]
clusters : dict
We call a list of file names a "cluster".
keys = size of clusters (number of elements (images))
keys = size of clusters (number of elements (images) `nelem`)
value = list of clusters with that size
{nelem : [[filename, filename, ...],
[filename, filename, ...],
Expand Down Expand Up @@ -220,18 +237,11 @@ def print_cluster_stats(clusters):
stats = cluster_stats(clusters)
for nelem in np.sort(list(stats.keys())):
print("{} : {}".format(nelem, stats[nelem]))
if len(stats) > 0:
nimg = np.array(list(stats.items())).prod(axis=1).sum()
else:
nimg = 0
print("#images in clusters total: ", nimg)



def make_links(clusters, cluster_dr):
print("cluster dir: {}".format(cluster_dr))
if os.path.exists(cluster_dr):
shutil.rmtree(cluster_dr)
for nelem, group in clusters.items():
for iclus, cluster in enumerate(group):
dr = pj(cluster_dr,
'cluster_with_{}'.format(nelem),
'cluster_{}'.format(iclus))
for fn in cluster:
link = pj(dr, os.path.basename(fn))
os.makedirs(os.path.dirname(link), exist_ok=True)
os.symlink(os.path.abspath(fn), link)
70 changes: 55 additions & 15 deletions imagecluster/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,28 @@
import os, re
import numpy as np
from imagecluster import imagecluster as ic
import os

from imagecluster import calc as ic
from imagecluster import common as co
from imagecluster import postproc as pp

pj = os.path.join


ic_base_dir = 'imagecluster'


def main(imagedir, sim=0.5, layer='fc2'):
def main(imagedir, sim=0.5, layer='fc2', size=(224,224), links=True, vis=False,
maxelem=None):
"""Example main app using this library.
Upon first invocation, the image and fingerprint databases are built and
written to disk. Each new invocation loads those and only repeats
* clustering
* creation of links to files in clusters
* visualization (if `vis=True`)
This is good for playing around with the `sim` parameter, for
instance, which only influences clustering.
Parameters
----------
imagedir : str
Expand All @@ -20,18 +32,46 @@ def main(imagedir, sim=0.5, layer='fc2'):
layer : str
which layer to use as feature vector (see
:func:`imagecluster.get_model`)
size : tuple
input image size (width, height), must match `model`, e.g. (224,224)
links : bool
create dirs with links
vis : bool
plot images in clusters
maxelem : max number of images per cluster for visualization (see
:mod:`~postproc`)
Notes
-----
imagedir : To select only a subset of the images, create an `imagedir` and
symlink your selected images there. In the future, we may add support
for passing a list of files, should the need arise. But then again,
this function is only an example front-end.
"""
dbfn = pj(imagedir, ic_base_dir, 'fingerprints.pk')
if not os.path.exists(dbfn):
os.makedirs(os.path.dirname(dbfn), exist_ok=True)
print("no fingerprints database {} found".format(dbfn))
files = co.get_files(imagedir)
fps_fn = pj(imagedir, ic_base_dir, 'fingerprints.pk')
ias_fn = pj(imagedir, ic_base_dir, 'images.pk')
ias = None
if not os.path.exists(fps_fn):
print(f"no fingerprints database {fps_fn} found")
os.makedirs(os.path.dirname(fps_fn), exist_ok=True)
model = ic.get_model(layer=layer)
print("running all images through NN model ...".format(dbfn))
fps = ic.fingerprints(files, model, size=(224,224))
co.write_pk(fps, dbfn)
if not os.path.exists(ias_fn):
print(f"create image array database {ias_fn}")
ias = ic.image_arrays(imagedir, size=size)
co.write_pk(ias, ias_fn)
else:
ias = co.read_pk(ias_fn)
print("running all images through NN model ...")
fps = ic.fingerprints(ias, model)
co.write_pk(fps, fps_fn)
else:
print("loading fingerprints database {} ...".format(dbfn))
fps = co.read_pk(dbfn)
print(f"loading fingerprints database {fps_fn} ...")
fps = co.read_pk(fps_fn)
print("clustering ...")
ic.make_links(ic.cluster(fps, sim), pj(imagedir, ic_base_dir, 'clusters'))
clusters = ic.cluster(fps, sim)
if links:
pp.make_links(clusters, pj(imagedir, ic_base_dir, 'clusters'))
if vis:
if ias is None:
ias = co.read_pk(ias_fn)
pp.visualize(clusters, ias, maxelem=maxelem)
62 changes: 62 additions & 0 deletions imagecluster/postproc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
import shutil

from matplotlib import pyplot as plt
import numpy as np

from . import calc as ic

pj = os.path.join


def plot_clusters(clusters, ias, maxelem=None):
"""Plot `clusters` of images in `ias`.
For interactive work, use :func:`visualize` instead.
Parameters
----------
clusters : see :func:`imagecluster.cluster`
ias : see :func:`imagecluster.image_arrays`
"""
stats = ic.cluster_stats(clusters)
ncols = sum(list(stats.values()))
nrows = max(stats.keys())
if maxelem is not None:
nrows = min(maxelem, nrows)
shape = ias[list(ias.keys())[0]].shape[:2]
arr = np.ones((nrows*shape[0], ncols*shape[1], 3), dtype=int) * 255
icol = -1
for nelem in np.sort(list(clusters.keys())):
for cluster in clusters[nelem]:
icol += 1
for irow, filename in enumerate(cluster[:nrows]):
img_arr = ias[filename]
arr[irow*shape[0]:(irow+1)*shape[0], icol*shape[1]:(icol+1)*shape[1], :] = img_arr
fig_scale = 1/shape[0]
figsize = np.array(arr.shape[:2][::-1])*fig_scale
fig,ax = plt.subplots(figsize=figsize)
ax.imshow(arr)
ax.axis('off')
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
return fig,ax


def visualize(*args, **kwds):
plot_clusters(*args, **kwds)
plt.show()


def make_links(clusters, cluster_dr):
print("cluster dir: {}".format(cluster_dr))
if os.path.exists(cluster_dr):
shutil.rmtree(cluster_dr)
for nelem, group in clusters.items():
for iclus, cluster in enumerate(group):
dr = pj(cluster_dr,
'cluster_with_{}'.format(nelem),
'cluster_{}'.format(iclus))
for fn in cluster:
link = pj(dr, os.path.basename(fn))
os.makedirs(os.path.dirname(link), exist_ok=True)
os.symlink(os.path.abspath(fn), link)

0 comments on commit 872d45f

Please sign in to comment.