Skip to content
This repository has been archived by the owner on Dec 13, 2024. It is now read-only.

Commit

Permalink
Add user-facing option to select feature layer
Browse files Browse the repository at this point in the history
  • Loading branch information
elcorto committed Feb 8, 2019
1 parent c045ed0 commit 0ec5fa9
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
17 changes: 13 additions & 4 deletions imagecluster/imagecluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,18 @@
pj = os.path.join


def get_model():
"""Keras Model of the VGG16 network, with the output layer set to the
second-to-last fully connected layer 'fc2' of shape (4096,)."""
def get_model(layer='fc2'):
"""Keras Model of the VGG16 network, with the output layer set to `layer`.
The default layer is the second-to-last fully connected layer 'fc2' of
shape (4096,).
Parameters
----------
layer : str
which layer to extract (must be of shape (None, X)), e.g. 'fc2', 'fc1'
or 'flatten'
"""
# base_model.summary():
# ....
# block5_conv4 (Conv2D) (None, 15, 15, 512) 2359808
Expand All @@ -32,7 +41,7 @@ def get_model():
#
base_model = VGG16(weights='imagenet', include_top=True)
model = Model(inputs=base_model.input,
outputs=base_model.get_layer('fc2').output)
outputs=base_model.get_layer(layer).output)
return model


Expand Down
11 changes: 7 additions & 4 deletions imagecluster/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,25 @@
ic_base_dir = 'imagecluster'


def main(imagedir, sim=0.5):
def main(imagedir, sim=0.5, layer='fc2'):
"""Example main app using this library.
Parameters
----------
imagedir : str
path to directory with images
sim : float (0..1)
similarity index (see imagecluster.cluster())
similarity index (see :func:`imagecluster.cluster`)
layer : str
which layer to use as feature vector (see
:func:`imagecluster.get_model`)
"""
dbfn = pj(imagedir, ic_base_dir, 'fingerprints.pk')
if not os.path.exists(dbfn):
os.makedirs(os.path.dirname(dbfn), exist_ok=True)
print("no fingerprints database {} found".format(dbfn))
files = co.get_files(imagedir)
model = ic.get_model()
model = ic.get_model(layer=layer)
print("running all images through NN model ...".format(dbfn))
fps = ic.fingerprints(files, model, size=(224,224))
co.write_pk(fps, dbfn)
Expand Down

0 comments on commit 0ec5fa9

Please sign in to comment.