LIP0011
LIP | 11 |
---|---|
Title | Reduce logsum with custom Op |
Author | J. van de Wolfshaar |
Status | Draft |
Type | Standard |
Discussion | Issue #41 |
PR | |
Created | Feb 27, 2018 |
A key operation for computing the log-space values of `Sum` nodes is `reduce_logsum`. The current implementation uses several Ops to accomplish this. This LIP proposes to replace the mechanics of the current `reduce_logsum` with a custom Op.
In `utils/math.py`, the following function computes the sum over the last axis of its input in log-space:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.log, three-argument tf.name_scope)


def reduce_log_sum_v2(log_input, name=None):
    """Calculate log of a sum of elements of a tensor containing log values
    row-wise. This function implements this through a series of TF ops. See

    Args:
        log_input (Tensor): Tensor containing log values.
        name (str): Optional name for the operation's name scope.

    Returns:
        Tensor: The reduced tensor of shape ``(None, 1)``, where the first
            dimension corresponds to the first dimension of ``log_input``.
    """
    with tf.name_scope(name, "reduce_log_sum", [log_input]):
        log_max = tf.reduce_max(log_input, -1, keepdims=True)
        # Compute the value assuming at least one input is not -inf
        log_rebased = tf.subtract(log_input, log_max)
        out_normal = log_max + tf.log(tf.reduce_sum(tf.exp(log_rebased),
                                                    -1, keepdims=True))
        # Check if all input values in a row are -inf (all non-log inputs are 0)
        # and produce output for that case
        # We use float('inf') for compatibility with Python<3.5
        # For Python>=3.5 we can use math.inf instead
        all_zero = tf.equal(log_max,
                            tf.constant(-float('inf'), dtype=log_input.dtype))
        out_zeros = tf.fill(tf.shape(out_normal),
                            tf.constant(-float('inf'), dtype=log_input.dtype))
        # Choose the output for each row
        return tf.where(all_zero, out_zeros, out_normal)
```
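When testing a replacement Op, it helps to have a reference for the exact semantics of `reduce_log_sum_v2`. The following is a minimal NumPy sketch, assuming a float array whose last axis is reduced with the dimension kept; `reduce_log_sum_ref` is a hypothetical helper name, not part of LibSPN.

```python
import numpy as np


def reduce_log_sum_ref(log_input):
    """NumPy mirror of reduce_log_sum_v2: log-sum over the last axis, keepdims=True."""
    log_input = np.asarray(log_input, dtype=np.float64)
    log_max = np.max(log_input, axis=-1, keepdims=True)
    # Rebasing an all -inf row computes -inf - (-inf) = nan; silence that warning,
    # since those rows are overwritten below anyway.
    with np.errstate(invalid="ignore"):
        log_rebased = log_input - log_max
        out_normal = log_max + np.log(
            np.sum(np.exp(log_rebased), axis=-1, keepdims=True))
    # Rows whose maximum is -inf (all non-log inputs are 0) get -inf explicitly.
    all_zero = log_max == -np.inf
    return np.where(all_zero, -np.inf, out_normal)


x = np.array([[np.log(1.0), np.log(2.0), np.log(3.0)],
              [-np.inf, -np.inf, -np.inf]])
r = reduce_log_sum_ref(x)
print(r)  # first row is log(6), second row is -inf
```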
Note that TensorFlow also has its own implementation:
```python
@deprecation.deprecated_args(
    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
                     axis=None,
                     keepdims=None,
                     name=None,
                     reduction_indices=None,
                     keep_dims=None):
  """Computes log(sum(exp(elements across dimensions of a tensor))).

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keepdims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(exp(input))). It avoids
  overflows caused by taking the exp of large inputs and underflows caused by
  taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default),
      reduces all dimensions. Must be in the range
      `[-rank(input_tensor), rank(input_tensor))`.
    keepdims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).
    reduction_indices: The old (deprecated) name for axis.
    keep_dims: Deprecated alias for `keepdims`.

  Returns:
    The reduced tensor.
  """
  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
                                                    "keep_dims", keep_dims)
  if keepdims is None:
    keepdims = False
  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
    raw_max = reduce_max(
        input_tensor,
        axis=axis,
        reduction_indices=reduction_indices,
        keepdims=True)
    my_max = array_ops.stop_gradient(
        array_ops.where(
            gen_math_ops.is_finite(raw_max), raw_max,
            array_ops.zeros_like(raw_max)))
    result = gen_math_ops.log(
        reduce_sum(
            gen_math_ops.exp(input_tensor - my_max),
            axis,
            keepdims=True,
            reduction_indices=reduction_indices)) + my_max
    if not keepdims:
      if isinstance(axis, int):
        axis = [axis]
      result = array_ops.squeeze(result, axis)
    return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
```
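The key difference in TensorFlow's version is that non-finite maxima are replaced by zero before rebasing, so no separate `where` on the output is needed. Below is a minimal NumPy sketch of that idea; `reduce_logsumexp_ref` is a hypothetical name, and the `stop_gradient` call has no NumPy counterpart, so it is omitted.

```python
import numpy as np


def reduce_logsumexp_ref(x, axis=-1, keepdims=False):
    """NumPy sketch of the max-substitution trick used by tf.reduce_logsumexp."""
    x = np.asarray(x, dtype=np.float64)
    raw_max = np.max(x, axis=axis, keepdims=True)
    # Replace non-finite maxima by 0 so that x - my_max never computes inf - inf.
    my_max = np.where(np.isfinite(raw_max), raw_max, 0.0)
    with np.errstate(divide="ignore"):
        # Rows that are entirely -inf sum to 0, and log(0) = -inf; suppress the warning.
        result = np.log(np.sum(np.exp(x - my_max), axis=axis, keepdims=True)) + my_max
    if not keepdims:
        result = np.squeeze(result, axis=axis)
    return result


x = np.array([[0.0, 0.0],
              [-np.inf, -np.inf],
              [np.inf, 1.0]])
r = reduce_logsumexp_ref(x)
print(r)  # log(2), -inf, inf
```

Unlike the explicit `-inf` check in `utils/math.py`, this trick also yields `+inf` (rather than NaN) for rows containing `+inf`.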
This function computes the reduction over an `axis`, which involves first subtracting the maximum over the axis, then performing `exp`, `reduce_sum` (over the same axis) and `log`, and finally adding back the `max` values that were found in the first step. Non-finite `max` values need special treatment, since rebasing a row whose maximum is `-inf` would compute `-inf - (-inf)`, which is NaN. TensorFlow's implementation replaces such maxima with zero before rebasing, so a row of all `-inf` yields `log(0) + 0 = -inf`. In the current implementation of `reduce_logsum` in `utils/math.py`, the result is explicitly set to `-inf` whenever the maximum of the same row is `-inf`.
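A quick NumPy check of the `-inf` edge case: naive rebasing of an all `-inf` row produces NaN, whereas `np.logaddexp.reduce`, used here as an independent reference (an assumption of this sketch, not part of either implementation), returns `-inf`.

```python
import numpy as np

row = np.array([-np.inf, -np.inf, -np.inf])

with np.errstate(invalid="ignore"):
    rebased = row - np.max(row)  # -inf - (-inf) is NaN in IEEE 754 arithmetic
print(rebased)  # [nan nan nan]

# np.logaddexp computes log(exp(a) + exp(b)) and handles -inf inputs without NaN.
print(np.logaddexp.reduce(row))                      # -inf
print(np.logaddexp.reduce(np.log([1.0, 2.0, 3.0])))  # log(6), about 1.7918
```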
We can implement a custom Op that performs all of the above in a single `OpKernel`.
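This LIP does not prescribe how the `OpKernel` computes its result. One option, sketched below in plain Python for a single row, is a streaming ("online") log-sum-exp that tracks the maximum and accumulates the rescaled sum in one pass, illustrating the kind of fusion a single kernel allows. `log_sum_row` is a hypothetical name; an actual TensorFlow `OpKernel` is written in C++, and this sketch only follows the `-inf` convention of `utils/math.py`.

```python
import math


def log_sum_row(log_values):
    """Single-pass log(sum(exp(v))) over one row; returns -inf for all -inf rows."""
    running_max = -math.inf  # largest value seen so far
    scaled_sum = 0.0         # sum of exp(v - running_max) over values seen so far
    for v in log_values:
        if v == -math.inf:
            continue  # exp(-inf) contributes nothing; skipping avoids -inf - (-inf)
        if v > running_max:
            # Rescale the accumulated sum to the new maximum, then add exp(0) = 1.
            scaled_sum = scaled_sum * math.exp(running_max - v) + 1.0
            running_max = v
        else:
            scaled_sum += math.exp(v - running_max)
    if scaled_sum == 0.0:
        return -math.inf  # every entry was -inf (all non-log inputs were 0)
    return running_max + math.log(scaled_sum)


print(log_sum_row([math.log(1.0), math.log(2.0), math.log(3.0)]))  # about log(6)
print(log_sum_row([-math.inf, -math.inf]))  # -inf
```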
The comparison below was performed on a system with an i7-4200HQ CPU and a GTX 960 GPU. It shows that the TensorFlow graph size (`TF_size`) for an MNIST training task with MARGINAL-LOG inference is reduced by 13 percent. The rest run time (`rest_run_time`) also decreases by about 7 percent.
InferenceType: MARGINAL-LOG (GPU)

| op | multi_nodes | SPN_size | TF_size | mem_used | name | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy |
|---|---|---|---|---|---|---|---|---|---|---|
| mnist_01 | True | 1400 | 59636 | 507.5873 | current | 38564.98 | 3607.40 | 5468.59 | 1811.08 | 99.5272 |
| mnist_01 | True | 1400 | 51875 | 507.5873 | custom | 33463.82 | 3373.99 | 4433.15 | 1692.15 | 99.6690 |

The results on an E5-1650v3 system show the same 13 percent reduction in graph size and a roughly 11 percent lower rest run time:

| op_name | on_gpu | multi_nodes | spn_size | tf_size | memory_used | input_dist | setup_time | weights_init_time | first_run_time | rest_run_time | test_accuracy | config |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mnist_01 | True | True | 1400 | 59636 | 500634624 | MIXTURE | 30.6658 | 3.07669 | 4037.92 | 1248.58 | 0.996217 | custom_reduce_logsum=False |
| mnist_01 | True | True | 1400 | 51875 | 505362176 | MIXTURE | 24.3156 | 2.78745 | 3117.8 | 1114.43 | 0.99669 | custom_reduce_logsum=True |

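The relative reductions mentioned in the text can be recomputed from the two result tables. This is plain arithmetic on the reported numbers; the dictionary layout is only an illustrative choice.

```python
def relative_reduction(before, after):
    """Percentage by which `after` is smaller than `before`."""
    return 100.0 * (before - after) / before


# (current, custom) pairs copied from the two result tables.
results = {
    "i7-4200HQ / GTX 960": {"tf_size": (59636, 51875), "rest_run_time": (1811.08, 1692.15)},
    "E5-1650v3": {"tf_size": (59636, 51875), "rest_run_time": (1248.58, 1114.43)},
}

for system, metrics in results.items():
    for metric, (current, custom) in metrics.items():
        print(f"{system}: {metric} reduced by {relative_reduction(current, custom):.1f}%")
```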