From 9477900946f923cb43ed76ed215490d01474bfe7 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Tue, 4 Apr 2017 15:51:13 -0800 Subject: [PATCH] Backport fixes and improvements from external Keras. Change: 152198296 --- .../contrib/keras/python/keras/__init__.py | 2 +- .../contrib/keras/python/keras/activations.py | 24 ++- .../python/keras/applications/resnet50.py | 4 +- .../contrib/keras/python/keras/backend.py | 94 +++++++---- .../keras/python/keras/engine/topology.py | 30 +++- .../keras/python/keras/engine/training.py | 26 ++- .../keras/python/keras/initializers.py | 9 +- .../python/keras/layers/convolutional.py | 36 ++-- .../keras/layers/convolutional_recurrent.py | 2 +- .../contrib/keras/python/keras/layers/core.py | 6 +- .../keras/python/keras/layers/local.py | 16 +- .../keras/python/keras/layers/merge.py | 156 ++++++++++++++++-- .../python/keras/layers/normalization.py | 2 +- .../keras/python/keras/layers/pooling.py | 16 +- .../keras/python/keras/layers/recurrent.py | 26 ++- .../keras/python/keras/layers/wrappers.py | 50 ++++-- .../contrib/keras/python/keras/metrics.py | 9 +- .../contrib/keras/python/keras/models.py | 23 ++- .../keras/python/keras/preprocessing/image.py | 2 +- .../keras/python/keras/utils/generic_utils.py | 5 +- .../keras/python/keras/utils/layer_utils.py | 2 +- .../python/keras/wrappers/scikit_learn.py | 34 +++- 22 files changed, 424 insertions(+), 150 deletions(-) diff --git a/tensorflow/contrib/keras/python/keras/__init__.py b/tensorflow/contrib/keras/python/keras/__init__.py index cdfc40dff1dcaf..ec316253dbacb9 100644 --- a/tensorflow/contrib/keras/python/keras/__init__.py +++ b/tensorflow/contrib/keras/python/keras/__init__.py @@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import wrappers -__version__ = '2.0.0-tf' +__version__ = '2.0.2-tf' diff --git a/tensorflow/contrib/keras/python/keras/activations.py b/tensorflow/contrib/keras/python/keras/activations.py index 1eac52dfad6b57..67762c83ba2960 100644 --- a/tensorflow/contrib/keras/python/keras/activations.py +++ b/tensorflow/contrib/keras/python/keras/activations.py @@ -24,18 +24,28 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object -def softmax(x): +def softmax(x, axis=-1): + """Softmax activation function. + + Arguments: + x : Tensor. + axis: Integer, axis along which the softmax normalization is applied. + + Returns: + Tensor, output of softmax transformation. + + Raises: + ValueError: In case `dim(x) == 1`. + """ ndim = K.ndim(x) if ndim == 2: return K.softmax(x) - elif ndim == 3: - e = K.exp(x - K.max(x, axis=-1, keepdims=True)) - s = K.sum(e, axis=-1, keepdims=True) + elif ndim > 2: + e = K.exp(x - K.max(x, axis=axis, keepdims=True)) + s = K.sum(e, axis=axis, keepdims=True) return e / s else: - raise ValueError('Cannot apply softmax to a tensor ' - 'that is not 2D or 3D. ' - 'Here, ndim=' + str(ndim)) + raise ValueError('Cannot apply softmax to a tensor that is 1D') def elu(x, alpha=1.0): diff --git a/tensorflow/contrib/keras/python/keras/applications/resnet50.py b/tensorflow/contrib/keras/python/keras/applications/resnet50.py index 546fcb9433abdc..12f7ca424edb85 100644 --- a/tensorflow/contrib/keras/python/keras/applications/resnet50.py +++ b/tensorflow/contrib/keras/python/keras/applications/resnet50.py @@ -163,8 +163,8 @@ def ResNet50(include_top=True, specified in your Keras config file. Arguments: - include_top: whether to include the 3 fully-connected - layers at the top of the network. 
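For reference, a minimal sketch of the new `axis` argument added to `softmax` above (toy shapes; import paths follow the contrib layout used in this patch):

```python
import numpy as np
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras import activations

# Rank-3 input (batch, timesteps, classes); any rank > 2 is now supported
# and the normalization axis can be chosen explicitly.
x = K.variable(np.random.random((2, 5, 10)))
y = activations.softmax(x, axis=-1)       # normalize over the class axis
print(K.eval(K.sum(y, axis=-1)))          # every slice sums to ~1.0
```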
+ include_top: whether to include the fully-connected + layer at the top of the network. weights: one of `None` (random initialization) or "imagenet" (pre-training on ImageNet). input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index 9769bce3b059f9..d7c646c19a79b2 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -22,7 +22,6 @@ from __future__ import print_function from collections import defaultdict -import errno import json import os import warnings @@ -270,6 +269,7 @@ def clear_session(): reset_uids() _SESSION = None phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase') + _GRAPH_LEARNING_PHASES = {} _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase @@ -1257,6 +1257,34 @@ def prod(x, axis=None, keepdims=False): return math_ops.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims) +def cumsum(x, axis=0): + """Cumulative sum of the values in a tensor, alongside the specified axis. + + Arguments: + x: A tensor or variable. + axis: An integer, the axis to compute the sum. + + Returns: + A tensor of the cumulative sum of values of `x` along `axis`. + """ + axis = _normalize_axis(axis, ndim(x)) + return math_ops.cumsum(x, axis=axis) + + +def cumprod(x, axis=0): + """Cumulative product of the values in a tensor, alongside the specified axis. + + Arguments: + x: A tensor or variable. + axis: An integer, the axis to compute the product. + + Returns: + A tensor of the cumulative product of values of `x` along `axis`. + """ + axis = _normalize_axis(axis, ndim(x)) + return math_ops.cumprod(x, axis=axis) + + def var(x, axis=None, keepdims=False): """Variance of a tensor, alongside the specified axis. @@ -1330,8 +1358,7 @@ def any(x, axis=None, keepdims=False): """ axis = _normalize_axis(axis, ndim(x)) x = math_ops.cast(x, dtypes_module.bool) - x = math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims) - return math_ops.cast(x, dtypes_module.uint8) + return math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims) def all(x, axis=None, keepdims=False): @@ -1347,8 +1374,7 @@ def all(x, axis=None, keepdims=False): """ axis = _normalize_axis(axis, ndim(x)) x = math_ops.cast(x, dtypes_module.bool) - x = math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims) - return math_ops.cast(x, dtypes_module.uint8) + return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims) def argmax(x, axis=-1): @@ -1645,7 +1671,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3): """ mean, var = nn.moments( x, reduction_axes, shift=None, name=None, keep_dims=False) - if sorted(reduction_axes) == range(ndim(x))[:-1]: + if sorted(reduction_axes) == list(range(ndim(x)))[:-1]: normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon) else: # need broadcasting @@ -2324,8 +2350,8 @@ def rnn(step_function, (no time dimension), containing the initial values for the states used in the step function. - go_backwards: boolean. If True, do the iteration over - the time dimension in reverse order. + go_backwards: boolean. If True, do the iteration over the time + dimension in reverse order and return the reversed sequence. mask: binary tensor with shape `(samples, time, 1)`, with a zero for every element that is masked. constants: a list of constant values passed at each step. 
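The `cumsum`/`cumprod` backend helpers added above follow the NumPy convention; a small usage sketch with toy values:

```python
import numpy as np
from tensorflow.contrib.keras.python.keras import backend as K

x = K.variable(np.array([[1., 2., 3.],
                         [4., 5., 6.]]))
print(K.eval(K.cumsum(x, axis=1)))    # [[ 1.  3.  6.], [ 4.  9. 15.]]
print(K.eval(K.cumprod(x, axis=0)))   # [[ 1.  2.  3.], [ 4. 10. 18.]]
```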
@@ -2414,9 +2440,9 @@ def rnn(step_function, states = return_states successive_outputs.append(output) successive_states.append(states) - last_output = successive_outputs[-1] - new_states = successive_states[-1] - outputs = array_ops.stack(successive_outputs) + last_output = successive_outputs[-1] + new_states = successive_states[-1] + outputs = array_ops.stack(successive_outputs) else: for inp in input_list: output, states = step_function(inp, states + constants) @@ -3534,19 +3560,19 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1): # HIGH ORDER FUNCTIONS -def map_fn(fn, elems, name=None): +def map_fn(fn, elems, name=None, dtype=None): """Map the function fn over the elements elems and return the outputs. Arguments: fn: Callable that will be called upon each element in elems elems: tensor name: A string name for the map node in the graph + dtype: Output data type. Returns: - Tensor with first dimension equal to the elems and second depending on - fn + Tensor with dtype `dtype`. """ - return functional_ops.map_fn(fn, elems, name=name) + return functional_ops.map_fn(fn, elems, name=name, dtype=dtype) def foldl(fn, elems, initializer=None, name=None): @@ -3560,7 +3586,7 @@ def foldl(fn, elems, initializer=None, name=None): name: A string name for the foldl node in the graph Returns: - Same type and shape as initializer + Tensor with same type and shape as `initializer`. """ return functional_ops.foldl(fn, elems, initializer=initializer, name=name) @@ -3583,27 +3609,39 @@ def foldr(fn, elems, initializer=None, name=None): # Load Keras default configuration from config file if present. _keras_base_dir = os.path.expanduser('~') -if not os.access(_keras_base_dir, os.W_OK): - _keras_base_dir = '/tmp' _keras_dir = os.path.join(_keras_base_dir, '.keras') -if not os.path.exists(_keras_dir): - try: - os.makedirs(_keras_dir) - except OSError as e: - if e.errno == errno.EEXIST: - pass - else: - raise _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json')) if os.path.exists(_config_path): - _config = json.load(open(_config_path)) + try: + _config = json.load(open(_config_path)) + except json.decoder.JSONDecodeError: + _config = {} _floatx = _config.get('floatx', floatx()) assert _floatx in {'float16', 'float32', 'float64'} _epsilon = _config.get('epsilon', epsilon()) assert isinstance(_epsilon, float) - _backend = backend() _image_data_format = _config.get('image_data_format', image_data_format()) assert _image_data_format in {'channels_last', 'channels_first'} set_floatx(_floatx) set_epsilon(_epsilon) set_image_data_format(_image_data_format) + +# Save config file. +if os.access(_keras_base_dir, os.W_OK): + if not os.path.exists(_keras_dir): + try: + os.makedirs(_keras_dir) + except OSError: + # Except potential race conditions + # in multi-threaded environments. + pass + + if not os.path.exists(_config_path): + _config = { + 'floatx': floatx(), + 'epsilon': epsilon(), + 'backend': 'tensorflow', + 'image_data_format': image_data_format() + } + with open(_config_path, 'w') as f: + f.write(json.dumps(_config, indent=4)) diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 0f506ff0a46b35..e33268235f0387 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -295,8 +295,14 @@ def __init__(self, **kwargs): # are only applicable to input layers: do not pass these keywords # to non-input layers. 
allowed_kwargs = { - 'input_shape', 'batch_input_shape', 'batch_size', 'dtype', 'name', - 'trainable', 'weights' + 'input_shape', + 'batch_input_shape', + 'batch_size', + 'dtype', + 'name', + 'trainable', + 'weights', + 'input_dtype', # legacy } for kwarg in kwargs: if kwarg not in allowed_kwargs: @@ -320,8 +326,15 @@ def __init__(self, **kwargs): batch_size = None batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) self.batch_input_shape = batch_input_shape - dtype = kwargs.get('dtype', K.floatx()) + + # Set dtype. + dtype = kwargs.get('dtype') + if dtype is None: + dtype = kwargs.get('input_dtype') + if dtype is None: + dtype = K.floatx() self.dtype = dtype + if 'weights' in kwargs: self._initial_weights = kwargs['weights'] else: @@ -485,11 +498,12 @@ def assert_input_compatibility(self, inputs): ': expected shape=' + str(spec.shape) + ', found shape=' + str(x_shape)) - def call(self, inputs): + def call(self, inputs, **kwargs): # pylint: disable=unused-argument """This is where the layer's logic lives. Arguments: - inputs: input tensor, or list/tuple of input tensors. + inputs: Input tensor, or list/tuple of input tensors. + **kwargs: Additional keyword arguments. Returns: A tensor or list/tuple of tensors. @@ -518,6 +532,8 @@ def __call__(self, inputs, **kwargs): ValueError: in case the layer is missing shape information for its `build` call. """ + if isinstance(inputs, list): + inputs = inputs[:] with K.name_scope(self.name): # Handle laying building (weight creating, input spec locking). if not self.built: @@ -1417,7 +1433,7 @@ class Container(Layer): get_weights set_weights get_config - get_output_shape_for + compute_output_shape # Class Methods from_config @@ -2029,7 +2045,7 @@ def _compute_output_shape(self, input_shape): for i in range(len(input_shapes)): layer = self.input_layers[i] input_shape = input_shapes[i] - # It's an input layer: get_output_shape_for is identity, + # It's an input layer: compute_output_shape is identity, # and there is only one node and one tensor output. shape_key = layer.name + '_0_0' layers_to_output_shapes[shape_key] = input_shape diff --git a/tensorflow/contrib/keras/python/keras/engine/training.py b/tensorflow/contrib/keras/python/keras/engine/training.py index efd437f6f66faf..0097c4a1c2c1bd 100644 --- a/tensorflow/contrib/keras/python/keras/engine/training.py +++ b/tensorflow/contrib/keras/python/keras/engine/training.py @@ -733,11 +733,12 @@ def compile(self, loss_functions = [] for name in self.output_names: if name not in loss: - warnings.warn('Output "' + name + '" missing from loss dictionary. ' - 'We assume this was done on purpose, ' - 'and we will not be expecting ' - 'any data to be passed to "' + name + - '" during training.') + warnings.warn( + 'Output "' + name + '" missing from loss dictionary. ' + 'We assume this was done on purpose, ' + 'and we will not be expecting ' + 'any data to be passed to "' + name + '" during training.', + stacklevel=2) loss_functions.append(losses.get(loss.get(name))) elif isinstance(loss, list): if len(loss) != len(self.outputs): @@ -1202,7 +1203,7 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0): if batch_index == 0: for batch_out in batch_outs: shape = (samples,) + batch_out.shape[1:] - outs.append(np.zeros(shape, dtype=K.floatx())) + outs.append(np.zeros(shape, dtype=batch_out.dtype)) for i, batch_out in enumerate(batch_outs): outs[i][batch_start:batch_end] = batch_out @@ -1718,7 +1719,7 @@ def fit_generator(self, - a tuple (inputs, targets, sample_weights). 
All arrays should contain the same number of samples. The generator is expected to loop over its data - indefinitely. An epoch finishes when `samples_per_epoch` + indefinitely. An epoch finishes when `steps_per_epoch` samples have been seen by the model. steps_per_epoch: Total number of steps (batches of samples) to yield from `generator` before declaring one epoch @@ -1767,7 +1768,7 @@ def generate_arrays_from_file(path): f.close() model.fit_generator(generate_arrays_from_file('/my_file.txt'), - samples_per_epoch=10000, epochs=10) + steps_per_epoch=10000, epochs=10) ``` Raises: @@ -2028,7 +2029,8 @@ def predict_generator(self, steps, max_q_size=10, workers=1, - pickle_safe=False): + pickle_safe=False, + verbose=0): """Generates predictions for the input samples from a data generator. The generator should return the same kind of data as accepted by @@ -2048,6 +2050,7 @@ def predict_generator(self, non picklable arguments to the generator as they can't be passed easily to children processes. + verbose: verbosity mode, 0 or 1. Returns: Numpy array(s) of predictions. @@ -2067,6 +2070,9 @@ def predict_generator(self, enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe) enqueuer.start(workers=workers, max_q_size=max_q_size) + if verbose == 1: + progbar = Progbar(target=steps) + while steps_done < steps: generator_output = None while enqueuer.is_running(): @@ -2103,6 +2109,8 @@ def predict_generator(self, for i, out in enumerate(outs): all_outs[i].append(out) steps_done += 1 + if verbose == 1: + progbar.update(steps_done) finally: if enqueuer is not None: diff --git a/tensorflow/contrib/keras/python/keras/initializers.py b/tensorflow/contrib/keras/python/keras/initializers.py index 621069f424bd63..f9cb35e171e69c 100644 --- a/tensorflow/contrib/keras/python/keras/initializers.py +++ b/tensorflow/contrib/keras/python/keras/initializers.py @@ -45,14 +45,16 @@ def from_config(cls, config): class Zeros(Initializer): - """Initializer that generates tensors initialized to 0.""" + """Initializer that generates tensors initialized to 0. + """ def __call__(self, shape, dtype=None): return K.constant(0, shape=shape, dtype=dtype) class Ones(Initializer): - """Initializer that generates tensors initialized to 1.""" + """Initializer that generates tensors initialized to 1. + """ def __call__(self, shape, dtype=None): return K.constant(1, shape=shape, dtype=dtype) @@ -130,7 +132,7 @@ def get_config(self): class TruncatedNormal(Initializer): """Initializer that generates a truncated normal distribution. - These values are similar to values from a `random_normal_initializer` + These values are similar to values from a `RandomNormal` except that values more than two standard deviations from the mean are discarded and re-drawn. This is the recommended initializer for neural network weights and filters. 
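A minimal sketch of using the `TruncatedNormal` initializer described above (hypothetical layer size):

```python
from tensorflow.contrib.keras.python.keras import initializers
from tensorflow.contrib.keras.python.keras.layers import Dense

# Draws from N(0, 0.05) and re-draws any value falling more than two
# standard deviations from the mean.
init = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=1)
layer = Dense(64, kernel_initializer=init)
```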
@@ -161,6 +163,7 @@ class VarianceScaling(Initializer): With `distribution="normal"`, samples are drawn from a truncated normal distribution centered on zero, with `stddev = sqrt(scale / n)` where n is: + - number of input units in the weight tensor, if mode = "fan_in" - number of output units, if mode = "fan_out" - average of the numbers of input and output units, if mode = "fan_avg" diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/contrib/keras/python/keras/layers/convolutional.py index 1a28399a28fb1d..3b68022115a208 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional.py @@ -244,7 +244,7 @@ def get_config(self): 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': - initializers.serialize(self.kernel_initializer), + initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': @@ -289,7 +289,7 @@ class Conv1D(_Conv): any `dilation_rate` value != 1. padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive). `"causal"` results in causal (dilated) convolutions, e.g. output[t] - depends solely on input[:t-1]. Useful when modeling temporal data + does not depend on input[t+1:]. Useful when modeling temporal data where the model should not violate the temporal order. See [WaveNet: A Generative Model for Raw Audio, section 2.1](https://arxiv.org/abs/1609.03499). @@ -395,9 +395,9 @@ class Conv2D(_Conv): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -621,7 +621,7 @@ class Conv2DTranspose(Conv2D): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for @@ -637,9 +637,9 @@ class Conv2DTranspose(Conv2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -688,7 +688,7 @@ def __init__(self, kernel_size, strides=(1, 1), padding='valid', - data_format='channels_last', + data_format=None, activation=None, use_bias=True, kernel_initializer='glorot_uniform', @@ -845,9 +845,9 @@ class SeparableConv2D(Conv2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. 
`channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -1079,9 +1079,9 @@ class UpSampling2D(Layer): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -1257,7 +1257,7 @@ class ZeroPadding2D(Layer): - If tuple of 2 ints: interpreted as two different symmetric padding values for height and width: - `(symmetric_height_pad, symmetrc_width_pad)`. + `(symmetric_height_pad, symmetric_width_pad)`. - If tuple of 2 tuples of 2 ints: interpreted as `((top_pad, bottom_pad), (left_pad, right_pad))` @@ -1265,9 +1265,9 @@ class ZeroPadding2D(Layer): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -1498,7 +1498,7 @@ class Cropping2D(Layer): - If tuple of 2 ints: interpreted as two different symmetric cropping values for height and width: - `(symmetric_height_crop, symmetrc_width_crop)`. + `(symmetric_height_crop, symmetric_width_crop)`. - If tuple of 2 tuples of 2 ints: interpreted as `((top_crop, bottom_crop), (left_crop, right_crop))` @@ -1506,9 +1506,9 @@ class Cropping2D(Layer): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". 
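To illustrate the corrected `data_format` documentation above, a small sketch of the two input layouts (toy shapes; import paths assume the contrib layout used in this patch):

```python
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.layers import Conv2D, Input

# channels_last (default): inputs are (batch, height, width, channels).
x = Input(shape=(224, 224, 3))
y = Conv2D(16, (3, 3), data_format='channels_last')(x)
print(K.int_shape(y))   # (None, 222, 222, 16)

# channels_first: inputs are (batch, channels, height, width).
# Note: running channels_first convolutions typically requires a GPU.
x = Input(shape=(3, 224, 224))
y = Conv2D(16, (3, 3), data_format='channels_first')(x)
print(K.int_shape(y))   # (None, 16, 222, 222)
```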
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py index 4ed5046dc310a2..4d8ef44da7bbd2 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py @@ -357,7 +357,7 @@ def build(self, input_shape): self.states = [None, None] if self.data_format == 'channels_first': - channel_axis = 1 + channel_axis = 2 else: channel_axis = -1 if input_shape[channel_axis] is None: diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 1207cc119f20f2..8dd55aaa2e6f99 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -88,7 +88,7 @@ class Dropout(Layer): """Applies Dropout to the input. Dropout consists in randomly setting - a fraction `p` of input units to 0 at each update during training time, + a fraction `rate` of input units to 0 at each update during training time, which helps prevent overfitting. Arguments: @@ -140,7 +140,7 @@ class SpatialDropout1D(Dropout): between feature maps and should be used instead. Arguments: - p: float between 0 and 1. Fraction of the input units to drop. + rate: float between 0 and 1. Fraction of the input units to drop. Input shape: 3D tensor with shape: @@ -775,7 +775,7 @@ def get_config(self): 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': - initializers.serialize(self.kernel_initializer), + initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/contrib/keras/python/keras/layers/local.py index 3bf5ee4f0fcde0..895d6e3727c248 100644 --- a/tensorflow/contrib/keras/python/keras/layers/local.py +++ b/tensorflow/contrib/keras/python/keras/layers/local.py @@ -59,7 +59,8 @@ class LocallyConnected1D(Layer): specifying the stride length of the convolution. Specifying any stride value != 1 is incompatible with specifying any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). + padding: Currently only supports `"valid"` (case-insensitive). + `"same"` may be supported in the future. activation: Activation function to use. If you don't specify anything, no activation is applied (ie. "linear" activation: `a(x) = x`). @@ -188,7 +189,7 @@ def get_config(self): 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': - initializers.serialize(self.kernel_initializer), + initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': @@ -239,16 +240,15 @@ class LocallyConnected2D(Layer): specifying the strides of the convolution along the width and height. Can be a single integer to specify the same value for all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: one of `"valid"` or `"same"` (case-insensitive). + padding: Currently only support `"valid"` (case-insensitive). + `"same"` will be supported in future. data_format: A string, one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. 
`channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -460,7 +460,7 @@ def get_config(self): 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'bias_initializer': - initializers.serialize(self.kernel_initializer), + initializers.serialize(self.bias_initializer), 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'bias_regularizer': diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py index eea4313d31c5f5..d52bd2bbb3d672 100644 --- a/tensorflow/contrib/keras/python/keras/layers/merge.py +++ b/tensorflow/contrib/keras/python/keras/layers/merge.py @@ -41,6 +41,44 @@ def __init__(self, **kwargs): def _merge_function(self, inputs): raise NotImplementedError + def _compute_elemwise_op_output_shape(self, shape1, shape2): + """Computes the shape of the resultant of an elementwise operation. + + Arguments: + shape1: tuple or None. Shape of the first tensor + shape2: tuple or None. Shape of the second tensor + + Returns: + expected output shape when an element-wise operation is + carried out on 2 tensors with shapes shape1 and shape2. + tuple or None. + + Raises: + ValueError: if shape1 and shape2 are not compatible for + element-wise operations. + """ + if None in [shape1, shape2]: + return None + elif len(shape1) < len(shape2): + return self._compute_elemwise_op_output_shape(shape2, shape1) + elif not shape2: + return shape1 + output_shape = list(shape1[:-len(shape2)]) + for i, j in zip(shape1[-len(shape2):], shape2): + if i is None or j is None: + output_shape.append(None) + elif i == 1: + output_shape.append(j) + elif j == 1: + output_shape.append(i) + else: + if i != j: + raise ValueError('Operands could not be broadcast ' + 'together with shapes ' + str(shape1) + ' ' + str( + shape2)) + output_shape.append(i) + return tuple(output_shape) + def build(self, input_shape): # Used purely for shape validation. if not isinstance(input_shape, list): @@ -49,23 +87,107 @@ def build(self, input_shape): raise ValueError('A merge layer should be called ' 'on a list of at least 2 inputs. ' 'Got ' + str(len(input_shape)) + ' inputs.') - if all([shape is None for shape in input_shape]): - return - input_shapes = [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in input_shape - ] - # TODO(fchollet): handle shapes with None entries. - input_shapes_set = set(input_shapes) - if None in input_shapes_set: - input_shapes_set.remove(None) - if len(input_shapes_set) > 1: - raise ValueError('Only tensors of same shape can ' - 'be merged by layer' + self.name + - ' Got input shapes: %s' % input_shapes) + batch_sizes = [s[0] for s in input_shape if s is not None] + batch_sizes = set(batch_sizes) + batch_sizes -= set([None]) + if len(batch_sizes) > 1: + raise ValueError('Can not merge tensors with different ' + 'batch sizes. 
Got tensors with shapes : ' + str( + input_shape)) + if input_shape[0] is None: + output_shape = None + else: + output_shape = input_shape[0][1:] + for i in range(1, len(input_shape)): + if input_shape[i] is None: + shape = None + else: + shape = input_shape[i][1:] + output_shape = self._compute_elemwise_op_output_shape(output_shape, shape) + # If the inputs have different ranks, we have to reshape them + # to make them broadcastable. + if None not in input_shape and len(set(map(len, input_shape))) == 1: + self._reshape_required = False + else: + self._reshape_required = True def call(self, inputs): - return self._merge_function(inputs) + if self._reshape_required: + reshaped_inputs = [] + input_ndims = list(map(K.ndim, inputs)) + if None not in input_ndims: + # If ranks of all inputs are available, + # we simply expand each of them at axis=1 + # until all of them have the same rank. + max_ndim = max(input_ndims) + for x in inputs: + x_ndim = K.ndim(x) + for _ in range(max_ndim - x_ndim): + x = K.expand_dims(x, 1) + reshaped_inputs.append(x) + return self._merge_function(reshaped_inputs) + else: + # Transpose all inputs so that batch size is the last dimension. + # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size) + transposed = False + for x in inputs: + x_ndim = K.ndim(x) + if x_ndim is None: + x_shape = K.shape(x) + batch_size = x_shape[0] + new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)]) + x_transposed = K.reshape(x, + K.stack([batch_size, K.prod(x_shape[1:])])) + x_transposed = K.permute_dimensions(x_transposed, (1, 0)) + x_transposed = K.reshape(x_transposed, new_shape) + reshaped_inputs.append(x_transposed) + transposed = True + elif x_ndim > 1: + dims = list(range(1, x_ndim)) + [0] + reshaped_inputs.append(K.permute_dimensions(x, dims)) + transposed = True + else: + # We don't transpose inputs if they are 1D vectors or scalars. + reshaped_inputs.append(x) + y = self._merge_function(reshaped_inputs) + y_ndim = K.ndim(y) + if transposed: + # If inputs have been transposed, we have to transpose the output too. + if y_ndim is None: + y_shape = K.shape(y) + y_ndim = K.shape(y_shape)[0] + batch_size = y_shape[y_ndim - 1] + new_shape = K.concatenate( + [K.expand_dims(batch_size), y_shape[:y_ndim - 1]]) + y = K.reshape(y, (-1, batch_size)) + y = K.permute_dimensions(y, (1, 0)) + y = K.reshape(y, new_shape) + elif y_ndim > 1: + dims = [y_ndim - 1] + list(range(y_ndim - 1)) + y = K.permute_dimensions(y, dims) + return y + else: + return self._merge_function(inputs) + + def compute_output_shape(self, input_shape): + if input_shape[0] is None: + output_shape = None + else: + output_shape = input_shape[0][1:] + for i in range(1, len(input_shape)): + if input_shape[i] is None: + shape = None + else: + shape = input_shape[i][1:] + output_shape = self._compute_elemwise_op_output_shape(output_shape, shape) + batch_sizes = [s[0] for s in input_shape if s is not None] + batch_sizes = set(batch_sizes) + batch_sizes -= set([None]) + if len(batch_sizes) == 1: + output_shape = (list(batch_sizes)[0],) + output_shape + else: + output_shape = (None,) + output_shape + return output_shape def compute_mask(self, inputs, mask=None): if mask is None: @@ -219,8 +341,8 @@ def compute_mask(self, inputs, mask=None): for input_i, mask_i in zip(inputs, mask): if mask_i is None: # Input is unmasked. 
Append all 1s to masks, - # but cast it to uint8 first - masks.append(K.cast(K.ones_like(input_i), 'uint8')) + # but cast it to bool first + masks.append(K.cast(K.ones_like(input_i), 'bool')) elif K.ndim(mask_i) < K.ndim(input_i): # Mask is smaller than the input, expand it masks.append(K.expand_dims(mask_i)) diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/contrib/keras/python/keras/layers/normalization.py index 41c618cc79d6d8..d429cd6d9ba913 100644 --- a/tensorflow/contrib/keras/python/keras/layers/normalization.py +++ b/tensorflow/contrib/keras/python/keras/layers/normalization.py @@ -154,7 +154,7 @@ def call(self, inputs, training=None): broadcast_shape[self.axis] = input_shape[self.axis] # Determines whether broadcasting is needed. - needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1]) + needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1]) normed, mean, variance = K.normalize_batch_in_training( inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon) diff --git a/tensorflow/contrib/keras/python/keras/layers/pooling.py b/tensorflow/contrib/keras/python/keras/layers/pooling.py index e31caed3ecccc7..47c88bf4d0bd42 100644 --- a/tensorflow/contrib/keras/python/keras/layers/pooling.py +++ b/tensorflow/contrib/keras/python/keras/layers/pooling.py @@ -199,9 +199,9 @@ class MaxPooling2D(_Pooling2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -255,9 +255,9 @@ class AveragePooling2D(_Pooling2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -542,9 +542,9 @@ class GlobalAveragePooling2D(_GlobalPooling2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. + `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". @@ -577,9 +577,9 @@ class GlobalMaxPooling2D(_GlobalPooling2D): one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. `channels_last` corresponds to inputs with shape - `(batch, width, height, channels)` while `channels_first` + `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape - `(batch, channels, width, height)`. 
+ `(batch, channels, height, width)`. It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py index 06986d3eaad812..6301132f4d2367 100644 --- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py @@ -105,8 +105,16 @@ class Recurrent(Layer): # now model.output_shape == (None, 32) # note: `None` is the batch dimension. - # for subsequent layers, not need to specify the input size: + # for subsequent layers, no need to specify the input size: model.add(LSTM(16)) + + # to stack recurrent layers, you must use return_sequences=True + # on any recurrent layer that feeds into another recurrent layer. + # note that you only need to specify the input size on the first layer. + model = Sequential() + model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True)) + model.add(LSTM(32, return_sequences=True)) + model.add(LSTM(10)) ``` Arguments: @@ -116,7 +124,8 @@ class Recurrent(Layer): return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. go_backwards: Boolean (default False). - If True, process the input sequence backwards. + If True, process the input sequence backwards and return the + reversed sequence. stateful: Boolean (default False). If True, the last state for each sample at index i in a batch will be used as initial state for the sample of index i in the following batch. @@ -398,6 +407,7 @@ class SimpleRNN(Recurrent): units: Positive integer, dimensionality of the output space. activation: Activation function to use. If you don't specify anything, no activation is applied + If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, @@ -547,7 +557,7 @@ def step(self, inputs, states): def get_constants(self, inputs, training=None): constants = [] - if self.implementation == 0 and 0 < self.dropout < 1: + if self.implementation != 0 and 0 < self.dropout < 1: input_shape = K.int_shape(inputs) input_dim = input_shape[-1] ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) @@ -619,7 +629,7 @@ class GRU(Recurrent): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you don't specify anything, no activation is applied + If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. 
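A brief sketch of the `dropout`/`recurrent_dropout` arguments touched by the `get_constants` fixes above (toy dimensions):

```python
from tensorflow.contrib.keras.python.keras.layers import GRU, Dense
from tensorflow.contrib.keras.python.keras.models import Sequential

# dropout masks the input transformation, recurrent_dropout masks the
# recurrent state; both are only active at training time.
model = Sequential()
model.add(GRU(32, input_shape=(10, 8), dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy')
```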
@@ -792,7 +802,7 @@ def preprocess_input(self, inputs, training=None): def get_constants(self, inputs, training=None): constants = [] - if self.implementation == 0 and 0 < self.dropout < 1: + if self.implementation != 0 and 0 < self.dropout < 1: input_shape = K.int_shape(inputs) input_dim = input_shape[-1] ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) @@ -861,7 +871,7 @@ def step(self, inputs, states): if self.use_bias: x_z = K.bias_add(x_z, self.bias_z) x_r = K.bias_add(x_r, self.bias_r) - x_h = K.bias_add(x_r, self.bias_h) + x_h = K.bias_add(x_h, self.bias_h) else: raise ValueError('Unknown `implementation` mode.') z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0], @@ -924,7 +934,7 @@ class LSTM(Recurrent): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you don't specify anything, no activation is applied + If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. @@ -1127,7 +1137,7 @@ def preprocess_input(self, inputs, training=None): def get_constants(self, inputs, training=None): constants = [] - if self.implementation == 0 and 0 < self.dropout < 1: + if self.implementation != 0 and 0 < self.dropout < 1: input_shape = K.int_shape(inputs) input_dim = input_shape[-1] ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1))) diff --git a/tensorflow/contrib/keras/python/keras/layers/wrappers.py b/tensorflow/contrib/keras/python/keras/layers/wrappers.py index 75b4810e40bd2f..eeb67493ee3f4d 100644 --- a/tensorflow/contrib/keras/python/keras/layers/wrappers.py +++ b/tensorflow/contrib/keras/python/keras/layers/wrappers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +# pylint: disable=protected-access """Wrapper layers: layers that augment the functionality of another layer. """ from __future__ import absolute_import @@ -19,6 +20,7 @@ from __future__ import print_function import copy +import inspect from tensorflow.contrib.keras.python.keras import backend as K from tensorflow.contrib.keras.python.keras.engine import InputSpec @@ -70,9 +72,10 @@ def get_config(self): return dict(list(base_config.items()) + list(config.items())) @classmethod - def from_config(cls, config): + def from_config(cls, config, custom_objects=None): from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top - layer = deserialize_layer(config.pop('layer')) + layer = deserialize_layer( + config.pop('layer'), custom_objects=custom_objects) return cls(layer, **config) @@ -188,12 +191,15 @@ class Bidirectional(Wrapper): If None, the outputs will not be combined, they will be returned as a list. + Raises: + ValueError: In case of invalid `merge_mode` argument. 
+ Examples: ```python model = Sequential() model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, - 10))) + 10))) model.add(Bidirectional(LSTM(10))) model.add(Dense(5)) model.add(Activation('softmax')) @@ -242,29 +248,47 @@ def _compute_output_shape(self, input_shape): shape = self.forward_layer._compute_output_shape(input_shape) # pylint: disable=protected-access return [shape, copy.copy(shape)] - def call(self, inputs, mask=None): - y = self.forward_layer.call(inputs, mask) - y_rev = self.backward_layer.call(inputs, mask) + def call(self, inputs, training=None, mask=None): + kwargs = {} + func_args = inspect.getargspec(self.layer.call).args + if 'training' in func_args: + kwargs['training'] = training + if 'mask' in func_args: + kwargs['mask'] = mask + + y = self.forward_layer.call(inputs, **kwargs) + y_rev = self.backward_layer.call(inputs, **kwargs) if self.return_sequences: y_rev = K.reverse(y_rev, 1) if self.merge_mode == 'concat': - return K.concatenate([y, y_rev]) + output = K.concatenate([y, y_rev]) elif self.merge_mode == 'sum': - return y + y_rev + output = y + y_rev elif self.merge_mode == 'ave': - return (y + y_rev) / 2 + output = (y + y_rev) / 2 elif self.merge_mode == 'mul': - return y * y_rev + output = y * y_rev elif self.merge_mode is None: - return [y, y_rev] + output = [y, y_rev] + + # Properly set learning phase + if 0 < self.layer.dropout + self.layer.recurrent_dropout: + if self.merge_mode is None: + for out in output: + out._uses_learning_phase = True + else: + output._uses_learning_phase = True + return output def reset_states(self): self.forward_layer.reset_states() self.backward_layer.reset_states() def build(self, input_shape): - self.forward_layer.build(input_shape) - self.backward_layer.build(input_shape) + with K.name_scope(self.forward_layer.name): + self.forward_layer.build(input_shape) + with K.name_scope(self.backward_layer.name): + self.backward_layer.build(input_shape) self.built = True def compute_mask(self, inputs, mask): diff --git a/tensorflow/contrib/keras/python/keras/metrics.py b/tensorflow/contrib/keras/python/keras/metrics.py index d7266c94cf78c0..59d380f73bd859 100644 --- a/tensorflow/contrib/keras/python/keras/metrics.py +++ b/tensorflow/contrib/keras/python/keras/metrics.py @@ -43,12 +43,15 @@ def binary_accuracy(y_true, y_pred): def categorical_accuracy(y_true, y_pred): - return K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)) + return K.cast( + K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx()) def sparse_categorical_accuracy(y_true, y_pred): - return K.equal( - K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx())) + return K.cast( + K.equal( + K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), + K.floatx())), K.floatx()) def top_k_categorical_accuracy(y_true, y_pred, k=5): diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index 2be4431d03d397..5289bb732b1334 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -207,7 +207,7 @@ def load_model(filepath, custom_objects=None): ValueError: In case of an invalid savefile. 
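A minimal save/load round trip matching the corrected error message above (requires h5py; the file path is illustrative):

```python
from tensorflow.contrib.keras.python.keras.layers import Dense
from tensorflow.contrib.keras.python.keras.models import Sequential, load_model

model = Sequential([Dense(2, input_shape=(4,), activation='softmax')])
model.compile(optimizer='sgd', loss='categorical_crossentropy')
model.save('/tmp/example_model.h5')           # architecture + weights + optimizer state
restored = load_model('/tmp/example_model.h5')
```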
""" if h5py is None: - raise ImportError('`save_model` requires h5py.') + raise ImportError('`load_model` requires h5py.') if not custom_objects: custom_objects = {} @@ -1006,7 +1006,7 @@ def fit_generator(self, steps_per_epoch: Total number of steps (batches of samples) to yield from `generator` before declaring one epoch finished and starting the next epoch. It should typically - be equal to the number of unique samples if your dataset + be equal to the number of unique samples of your dataset divided by the batch size. epochs: Integer, total number of iterations on the data. verbose: Verbosity mode, 0, 1, or 2. @@ -1017,8 +1017,10 @@ def fit_generator(self, - A tuple (inputs, targets, sample_weights). validation_steps: Only relevant if `validation_data` is a generator. - Number of samples to use from validation generator - at the end of every epoch. + Number of steps to yield from validation generator + at the end of every epoch. It should typically + be equal to the number of unique samples of your + validation dataset divided by the batch size. class_weight: Dictionary mapping class indices to a weight for the class. max_q_size: Maximum size for the generator queue @@ -1050,7 +1052,7 @@ def generate_arrays_from_file(path): # and labels, from each line in the file x, y = process_line(line) yield (x, y) - f.close() + f.close() model.fit_generator(generate_arrays_from_file('/my_file.txt'), samples_per_epoch=10000, epochs=10) @@ -1119,7 +1121,8 @@ def predict_generator(self, steps, max_q_size=10, workers=1, - pickle_safe=False): + pickle_safe=False, + verbose=0): """Generates predictions for the input samples from a data generator. The generator should return the same kind of data as accepted by @@ -1136,6 +1139,7 @@ def predict_generator(self, relies on multiprocessing, you should not pass non picklable arguments to the generator as they can't be passed easily to children processes. + verbose: verbosity mode, 0 or 1. Returns: A Numpy array of predictions. 
@@ -1147,7 +1151,8 @@ def predict_generator(self, steps, max_q_size=max_q_size, workers=workers, - pickle_safe=pickle_safe) + pickle_safe=pickle_safe, + verbose=verbose) def get_config(self): config = [] @@ -1159,9 +1164,9 @@ def get_config(self): return copy.deepcopy(config) @classmethod - def from_config(cls, config): + def from_config(cls, config, custom_objects=None): model = cls() for conf in config: - layer = layer_module.deserialize(conf) + layer = layer_module.deserialize(conf, custom_objects=custom_objects) model.add(layer) return model diff --git a/tensorflow/contrib/keras/python/keras/preprocessing/image.py b/tensorflow/contrib/keras/python/keras/preprocessing/image.py index 86c7650a073b2e..de0749ae020219 100644 --- a/tensorflow/contrib/keras/python/keras/preprocessing/image.py +++ b/tensorflow/contrib/keras/python/keras/preprocessing/image.py @@ -785,7 +785,7 @@ def _flow_index(self, n, batch_size=32, shuffle=False, seed=None): index_array = np.random.permutation(n) current_index = (self.batch_index * batch_size) % n - if n >= current_index + batch_size: + if n > current_index + batch_size: current_batch_size = batch_size self.batch_index += 1 else: diff --git a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py index c1e02968353bc4..6e83fde2c9089b 100644 --- a/tensorflow/contrib/keras/python/keras/utils/generic_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/generic_utils.py @@ -172,7 +172,8 @@ def deserialize_keras_object(identifier, else: fn = module_objects.get(function_name) if fn is None: - raise ValueError('Unknown ' + printable_module_name, ':' + class_name) + raise ValueError('Unknown ' + printable_module_name, + ':' + function_name) return fn else: raise ValueError('Could not interpret serialized ' + printable_module_name + @@ -215,6 +216,8 @@ def func_load(code, defaults=None, closure=None, globs=None): """ if isinstance(code, (tuple, list)): # unpack previous dump code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) code = marshal.loads(code.encode('raw_unicode_escape')) if globs is None: globs = globals() diff --git a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py index 32e0de7d3dc19a..26878fdd57fb5b 100644 --- a/tensorflow/contrib/keras/python/keras/utils/layer_utils.py +++ b/tensorflow/contrib/keras/python/keras/utils/layer_utils.py @@ -171,7 +171,7 @@ def count_total_params(layers, layer_set=None): [K.count_params(p) for p in layer.trainable_weights]) non_trainable_count += np.sum( [K.count_params(p) for p in layer.non_trainable_weights]) - return trainable_count, non_trainable_count + return int(trainable_count), int(non_trainable_count) def convert_all_kernels_in_model(model): diff --git a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py index ecda890fec966e..323c31aee839aa 100644 --- a/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py +++ b/tensorflow/contrib/keras/python/keras/wrappers/scikit_learn.py @@ -194,6 +194,36 @@ class KerasClassifier(BaseWrapper): """Implementation of the scikit-learn classifier API for Keras. """ + def fit(self, x, y, **kwargs): + """Constructs a new model with `build_fn` & fit the model to `(x, y)`. 
+ + Arguments: + x : array-like, shape `(n_samples, n_features)` + Training samples where n_samples in the number of samples + and n_features is the number of features. + y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)` + True labels for X. + **kwargs: dictionary arguments + Legal arguments are the arguments of `Sequential.fit` + + Returns: + history : object + details about the training history at each epoch. + + Raises: + ValueError: In case of invalid shape for `y` argument. + """ + y = np.array(y) + if len(y.shape) == 2 and y.shape[1] > 1: + self.classes_ = np.arange(y.shape[1]) + elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1: + self.classes_ = np.unique(y) + y = np.searchsorted(self.classes_, y) + else: + raise ValueError('Invalid shape for y: ' + str(y.shape)) + self.n_classes_ = len(self.classes_) + return super(KerasClassifier, self).fit(x, y, **kwargs) + def predict(self, x, **kwargs): """Returns the class predictions for the given test data. @@ -210,7 +240,8 @@ def predict(self, x, **kwargs): Class predictions. """ kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs) - return self.model.predict_classes(x, **kwargs) + classes = self.model.predict_classes(x, **kwargs) + return self.classes_[classes] def predict_proba(self, x, **kwargs): """Returns class probability estimates for the given test data. @@ -261,6 +292,7 @@ def score(self, x, y, **kwargs): compute accuracy. You should pass `metrics=["accuracy"]` to the `.compile()` method of the model. """ + y = np.searchsorted(self.classes_, y) kwargs = self.filter_sk_params(Sequential.evaluate, kwargs) loss_name = self.model.loss
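A minimal end-to-end sketch of the updated `KerasClassifier` workflow with string labels (toy data; `build_model` is a hypothetical model factory):

```python
import numpy as np
from tensorflow.contrib.keras.python.keras.layers import Dense
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.wrappers.scikit_learn import KerasClassifier

def build_model():
    model = Sequential()
    model.add(Dense(8, input_shape=(4,), activation='relu'))
    model.add(Dense(3, activation='softmax'))
    model.compile(optimizer='adam', loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

x = np.random.random((30, 4))
y = np.random.choice(['a', 'b', 'c'], size=30)   # string labels now work

clf = KerasClassifier(build_fn=build_model, epochs=1, batch_size=8, verbose=0)
clf.fit(x, y)              # records clf.classes_ == ['a', 'b', 'c']
print(clf.predict(x[:5]))  # predictions mapped back to the original labels
print(clf.score(x, y))     # labels are re-encoded internally before scoring
```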