adding support for multi gpu and nasnet #97

Open; wants to merge 1 commit into base: master
3 changes: 3 additions & 0 deletions .gitignore
@@ -11,3 +11,6 @@ data/sequences/*
data/train/*
data/test/*
data/c3d/*


.DS_Store
18 changes: 9 additions & 9 deletions data.py
@@ -115,7 +115,7 @@ def split_train_test(self):
test.append(item)
return train, test

def get_all_sequences_in_memory(self, train_test, data_type):
def get_all_sequences_in_memory(self, train_test, data_type,cnn_model_type):
"""
This is a mirror of our generator, but attempts to load everything into
memory so we can train way faster.
@@ -137,19 +137,19 @@ def get_all_sequences_in_memory(self, train_test, data_type):
sequence = self.build_image_sequence(frames)

else:
sequence = self.get_extracted_sequence(data_type, row)
sequence = self.get_extracted_sequence(data_type, row,cnn_model_type=cnn_model_type)

if sequence is None:
print("Can't find sequence. Did you generate them?")
raise
raise(IOError)

X.append(sequence)
y.append(self.get_class_one_hot(row[1]))

return np.array(X), np.array(y)

@threadsafe_generator
def frame_generator(self, batch_size, train_test, data_type):
def frame_generator(self, batch_size, train_test, data_type,cnn_model_type):
"""Return a generator that we can use to train on. There are
a couple different things we can return:

@@ -182,7 +182,7 @@ def frame_generator(self, batch_size, train_test, data_type):
sequence = self.build_image_sequence(frames)
else:
# Get the sequence from disk.
sequence = self.get_extracted_sequence(data_type, sample)
sequence = self.get_extracted_sequence(data_type, sample,cnn_model_type=cnn_model_type)

if sequence is None:
raise ValueError("Can't find sequence. Did you generate them?")
@@ -196,17 +196,17 @@ def build_image_sequence(self, frames):
"""Given a set of frames (filenames), build our sequence."""
return [process_image(x, self.image_shape) for x in frames]

def get_extracted_sequence(self, data_type, sample):
def get_extracted_sequence(self, data_type, sample,cnn_model_type):
"""Get the saved extracted features."""
filename = sample[2]
path = os.path.join(self.sequence_path, filename + '-' + str(self.seq_length) + \
path = os.path.join(self.sequence_path, filename + '-' + str(self.seq_length) + '-' + cnn_model_type + \
'-' + data_type + '.npy')
if os.path.isfile(path):
return np.load(path)
else:
return None

def get_frames_by_filename(self, filename, data_type):
def get_frames_by_filename(self, filename, data_type,cnn_model_type):
"""Given a filename for one of our samples, return the data
the model needs to make predictions."""
# First, find the sample row.
@@ -226,7 +226,7 @@ def get_frames_by_filename(self, filename, data_type):
sequence = self.build_image_sequence(frames)
else:
# Get the sequence from disk.
sequence = self.get_extracted_sequence(data_type, sample)
sequence = self.get_extracted_sequence(data_type, sample,cnn_model_type=cnn_model_type)

if sequence is None:
raise ValueError("Can't find sequence. Did you generate them?")
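Note (not part of the diff): a minimal sketch of how a caller such as train.py would thread the new `cnn_model_type` argument through the updated `DataSet` methods. The concrete argument values below are illustrative assumptions, not taken from the PR.

```python
from data import DataSet

# Sequences must have been extracted with the same CNN (see extract_features.py).
data = DataSet(seq_length=40, class_limit=None)

# Load everything into memory at once...
X, y = data.get_all_sequences_in_memory('train', 'features', cnn_model_type='nasnet')

# ...or stream batches with the threadsafe generator.
generator = data.frame_generator(32, 'train', 'features', cnn_model_type='nasnet')
```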
8 changes: 4 additions & 4 deletions demo.py
@@ -12,7 +12,7 @@
from data import DataSet
import numpy as np

def predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit):
def predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit,cnn_model_type):
model = load_model(saved_model)

# Get the data and process it.
@@ -23,7 +23,7 @@ def predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit):
class_limit=class_limit)

# Extract the sample from the data.
sample = data.get_frames_by_filename(video_name, data_type)
sample = data.get_frames_by_filename(video_name, data_type,cnn_model_type=cnn_model_type)

# Predict!
prediction = model.predict(np.expand_dims(sample, axis=0))
@@ -48,7 +48,7 @@ def main():
# an actual video file, extract frames, generate sequences, etc.
#video_name = 'v_Archery_g04_c02'
video_name = 'v_ApplyLipstick_g01_c01'

cnn_model_type = 'InceptionV3'
# Choose images or features and image shape based on network.
if model in ['conv_3d', 'c3d', 'lrcn']:
data_type = 'images'
@@ -59,7 +59,7 @@
else:
raise ValueError("Invalid model. See train.py for options.")

predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit)
predict(data_type, seq_length, saved_model, image_shape, video_name, class_limit,cnn_model_type)

if __name__ == '__main__':
main()
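Note (not part of the diff): a hypothetical invocation of the updated `predict()`. The checkpoint path and shapes are assumptions; the `cnn_model_type` passed here has to match the one used when the features were extracted, or `get_extracted_sequence()` will not find the .npy file.

```python
# Illustrative only: argument values are assumptions, not taken from the PR.
predict(data_type='features', seq_length=40,
        saved_model='data/checkpoints/lstm-features.hdf5',  # hypothetical path
        image_shape=None, video_name='v_ApplyLipstick_g01_c01',
        class_limit=None, cnn_model_type='InceptionV3')
```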
17 changes: 11 additions & 6 deletions extract_features.py
@@ -20,19 +20,22 @@
# Set defaults.
seq_length = 40
class_limit = None # Number of classes to extract. Can be 1-101 or None for all.
# cnn_model_type='InceptionV3'
cnn_model_type='nasnet'

n_gpu=8
# Get the dataset.
data = DataSet(seq_length=seq_length, class_limit=class_limit)

# get the model.
model = Extractor()
model = Extractor(cnn_model_type=cnn_model_type, n_gpu=n_gpu)

# Loop through data.
pbar = tqdm(total=len(data.data))
for video in data.data:

# Get the path to the sequence for this video.
path = os.path.join('data', 'sequences', video[2] + '-' + str(seq_length) + \
path = os.path.join('data', 'sequences', video[2] + '-' + str(seq_length) + '-' + cnn_model_type + \
'-features') # numpy will auto-append .npy

# Check if we already have it.
@@ -46,11 +49,13 @@
# Now downsample to just the ones we need.
frames = data.rescale_list(frames, seq_length)

# Batch processing is more efficient on the GPU.
sequence= list(model.extract_batch(frames,cnn_model_type=cnn_model_type))
# Now loop through and extract features to build the sequence.
sequence = []
for image in frames:
features = model.extract(image)
sequence.append(features)
# sequence2 = []
# for image in frames:
# features = model.extract(image,cnn_model_type=cnn_model_type)
# sequence2.append(features)

# Save the sequence.
np.save(path, sequence)
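Note (not part of the diff): with this naming change, the saved feature files now embed the CNN used to extract them. A small sketch of the resulting path; the video id is the example used in demo.py.

```python
import os

# Illustrative only: shows the filename pattern produced above. A 40-frame
# sequence extracted with NASNet ends up as
#   data/sequences/v_ApplyLipstick_g01_c01-40-nasnet-features.npy
# get_extracted_sequence() in data.py rebuilds the same path, so the
# cnn_model_type used at training/prediction time must match this one.
seq_length = 40
cnn_model_type = 'nasnet'
video_id = 'v_ApplyLipstick_g01_c01'

path = os.path.join('data', 'sequences',
                    video_id + '-' + str(seq_length) + '-' + cnn_model_type + '-features')
```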
88 changes: 65 additions & 23 deletions extractor.py
@@ -1,33 +1,38 @@
from keras.preprocessing import image
from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.applications import inception_v3, nasnet
from keras.models import Model, load_model
from keras.layers import Input
from keras.utils import multi_gpu_model

import numpy as np

class Extractor():
def __init__(self, weights=None):
def __init__(self, weights=None, cnn_model_type='nasnet', n_gpu=1):
"""Either load pretrained from imagenet, or load our saved
weights from our own training."""

self.weights = weights # so we can check elsewhere which model

if weights is None:
# Get model with pretrained weights.
base_model = InceptionV3(
weights='imagenet',
include_top=True
)

# We'll extract features at the final pool layer.
self.model = Model(
inputs=base_model.input,
outputs=base_model.get_layer('avg_pool').output
)
if cnn_model_type == 'InceptionV3':
self.model = inception_v3.InceptionV3(
weights='imagenet',pooling='avg',
include_top=False
)
elif cnn_model_type == 'nasnet':
base_model = nasnet.NASNetLarge(
weights='imagenet',
include_top=True
)
# issue https://github.com/keras-team/keras/issues/10109
self.model = Model(
inputs=base_model.input,
outputs=base_model.get_layer('global_average_pooling2d_1').output
)

else:
# Load the model first.
self.model = load_model(weights)

# Then remove the top so we get features not predictions.
# From: https://github.com/fchollet/keras/issues/2371
self.model.layers.pop()
@@ -36,20 +41,57 @@ def __init__(self, weights=None):
self.model.output_layers = [self.model.layers[-1]]
self.model.layers[-1].outbound_nodes = []

def extract(self, image_path):
img = image.load_img(image_path, target_size=(299, 299))
if n_gpu>1:
self.model = multi_gpu_model(self.model,n_gpu)

def extract(self, image_path, cnn_model_type='nasnet'):
if cnn_model_type== 'InceptionV3':
target_size = (299, 299)
elif cnn_model_type== 'nasnet':
target_size = (331, 331)

img = image.load_img(image_path, target_size=target_size)
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

if cnn_model_type == 'InceptionV3':
x = inception_v3.preprocess_input(x)
elif cnn_model_type == 'nasnet':
x = nasnet.preprocess_input(x)

# Get the prediction.
features = self.model.predict(x)

if self.weights is None:
# For imagenet/default network:
features = features[0]
else:
# For loaded network:
features = features[0]
features = features[0]

return features


def extract_batch(self, image_path_list, cnn_model_type='InceptionV3'):
if cnn_model_type== 'InceptionV3':
target_size = (299, 299,3)
# feature_size = 2048

elif cnn_model_type== 'nasnet':
target_size = (331, 331,3)
# feature_size = 4032

batch_size = len(image_path_list)

X = np.zeros((batch_size,) + target_size )

for img_idx, image_path in enumerate(image_path_list):
img = image.load_img(image_path, target_size=target_size[0:2])
array = image.img_to_array(img)
X[img_idx] = array
# x = np.expand_dims(x, axis=0)

if cnn_model_type == 'InceptionV3':
X = inception_v3.preprocess_input(X)
elif cnn_model_type == 'nasnet':
X = nasnet.preprocess_input(X)

# Get the prediction.
features_batch = self.model.predict(X)

return features_batch
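Note (not part of the diff): a minimal usage sketch of the reworked Extractor. The frame paths and GPU count are assumptions.

```python
from extractor import Extractor

# Build the NASNetLarge feature extractor, replicated across 2 GPUs.
model = Extractor(cnn_model_type='nasnet', n_gpu=2)

frames = ['frame_0001.jpg', 'frame_0002.jpg']  # hypothetical frame paths
features = model.extract_batch(frames, cnn_model_type='nasnet')
print(features.shape)  # (2, 4032): NASNetLarge's global-average-pooled features
```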
69 changes: 49 additions & 20 deletions models.py
@@ -8,12 +8,34 @@
from keras.layers.wrappers import TimeDistributed
from keras.layers.convolutional import (Conv2D, MaxPooling3D, Conv3D,
MaxPooling2D)
from keras.utils import multi_gpu_model
from keras import Model

import tensorflow as tf

from collections import deque
import sys

# A wrapper class for multi-GPU training that keeps saving and loading on the serial model
class ModelMGPU(Model):
def __init__(self, ser_model, gpus):
pmodel = multi_gpu_model(ser_model, gpus)
self.__dict__.update(pmodel.__dict__)
self._smodel = ser_model

def __getattribute__(self, attrname):
'''Override load and save methods to be used from the serial-model. The
serial-model holds references to the weights in the multi-gpu model.
'''
# return Model.__getattribute__(self, attrname)
if 'load' in attrname or 'save' in attrname:
return getattr(self._smodel, attrname)

return super(ModelMGPU, self).__getattribute__(attrname)

class ResearchModels():
def __init__(self, nb_classes, model, seq_length,
saved_model=None, features_length=2048):
def __init__(self, nb_classes, model_type, seq_length,
saved_model=None, cnn_feature_size=4032,n_gpus = 8):
"""
`model` = one of:
lstm
@@ -40,45 +62,52 @@ def __init__(self, nb_classes, model, seq_length,

# Get the appropriate model.
if self.saved_model is not None:
print("Loading model %s" % self.saved_model)
self.model = load_model(self.saved_model)
elif model == 'lstm':
with tf.device('/cpu:0'):
print("Loading model %s" % self.saved_model)
serial_model = load_model(self.saved_model)
elif model_type == 'lstm':
print("Loading LSTM model.")
self.input_shape = (seq_length, features_length)
self.model = self.lstm()
elif model == 'lrcn':
self.input_shape = (seq_length, cnn_feature_size)
serial_model = self.lstm(cnn_feature_size=cnn_feature_size)
elif model_type == 'lrcn':
print("Loading CNN-LSTM model.")
self.input_shape = (seq_length, 80, 80, 3)
self.model = self.lrcn()
elif model == 'mlp':
serial_model = self.lrcn()
elif model_type == 'mlp':
print("Loading simple MLP.")
self.input_shape = (seq_length, features_length)
self.model = self.mlp()
elif model == 'conv_3d':
self.input_shape = (seq_length, cnn_feature_size)
serial_model= self.mlp()
elif model_type == 'conv_3d':
print("Loading Conv3D")
self.input_shape = (seq_length, 80, 80, 3)
self.model = self.conv_3d()
elif model == 'c3d':
serial_model = self.conv_3d()
elif model_type == 'c3d':
print("Loading C3D")
self.input_shape = (seq_length, 80, 80, 3)
self.model = self.c3d()
serial_model = self.c3d()
else:
print("Unknown network.")
sys.exit()

if n_gpus==1:
self.model = serial_model
else:
self.model = ModelMGPU(ser_model=serial_model,gpus=n_gpus)
# Now compile the network.
optimizer = Adam(lr=1e-5, decay=1e-6)
optimizer = Adam(lr=1e-5*n_gpus, decay=1e-6)
self.model.compile(loss='categorical_crossentropy', optimizer=optimizer,
metrics=metrics)

print(self.model.summary())
print(serial_model.summary())

def lstm(self):
def lstm(self,cnn_feature_size=4032):
"""Build a simple LSTM network. We pass the extracted features from
our CNN to this model predominantly."""
# Model.
model = Sequential()
model.add(LSTM(2048, return_sequences=False,


model.add(LSTM(cnn_feature_size, return_sequences=False,
input_shape=self.input_shape,
dropout=0.5))
model.add(Dense(512, activation='relu'))
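Note (not part of the diff): a sketch of how the refactored ResearchModels would be built on top of NASNet features. The class count and GPU count are assumptions; note that the learning rate is scaled by n_gpus in the constructor.

```python
from models import ResearchModels

# LSTM over 40-step sequences of 4032-dim NASNet features, replicated on 2 GPUs.
rm = ResearchModels(nb_classes=101, model_type='lstm', seq_length=40,
                    saved_model=None, cnn_feature_size=4032, n_gpus=2)

# rm.model is the compiled (multi-GPU) model; because ModelMGPU forwards
# save/load to the serial model, checkpoints written during training can be
# reloaded later on a single GPU.
rm.model.summary()
```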