
LIP 11 - Reduce logsum with custom Op

LIP:        11
Title:      Reduce logsum with custom Op
Author:     J. van de Wolfshaar
Status:     Draft
Type:       Standard
Discussion: Issue #41
PR:
Created:    Feb 27, 2018

Introduction

A key operation for computing the log-space values of Sum nodes is reduce_logsum. The current implementation chains several TF Ops to accomplish this. This LIP proposes to replace the mechanics of the current reduce_logsum with a single custom Op.

Technical Background

In utils/math.py, the following function computes the sum over the last axis of its input in log-space:

import tensorflow as tf


def reduce_log_sum_v2(log_input, name=None):
    """Calculate the log of a sum of elements of a tensor containing log
    values, row-wise. This function implements the reduction through a
    series of TF ops.

    Args:
        log_input (Tensor): Tensor containing log values.

    Returns:
        Tensor: The reduced tensor of shape ``(None, 1)``, where the first
        dimension corresponds to the first dimension of ``log_input``.
    """
    with tf.name_scope(name, "reduce_log_sum", [log_input]):
        log_max = tf.reduce_max(log_input, -1, keepdims=True)
        # Compute the value assuming at least one input is not -inf
        log_rebased = tf.subtract(log_input, log_max)
        out_normal = log_max + tf.log(tf.reduce_sum(tf.exp(log_rebased),
                                                    -1, keepdims=True))
        # Check if all input values in a row are -inf (all non-log inputs are 0)
        # and produce output for that case
        # We use float('inf') for compatibility with Python<3.5
        # For Python>=3.5 we can use math.inf instead
        all_zero = tf.equal(log_max,
                            tf.constant(-float('inf'), dtype=log_input.dtype))
        out_zeros = tf.fill(tf.shape(out_normal),
                            tf.constant(-float('inf'), dtype=log_input.dtype))
        # Choose the output for each row
        return tf.where(all_zero, out_zeros, out_normal)
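
As a quick illustration (a hypothetical snippet, assuming reduce_log_sum_v2 above is in scope), the function reduces an ordinary row to the log of its sum and maps a row of all -inf values (all non-log inputs are 0) to -inf:

import numpy as np
import tensorflow as tf

# Assumes reduce_log_sum_v2 from utils/math.py is defined in scope
log_input = tf.constant([[np.log(0.2), np.log(0.3)],   # ordinary row
                         [-np.inf, -np.inf]],          # all non-log inputs are 0
                        dtype=tf.float64)
reduced = reduce_log_sum_v2(log_input)

with tf.Session() as sess:
    print(sess.run(reduced))  # [[log(0.5)], [-inf]] ~= [[-0.6931], [-inf]]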

Note that TensorFlow also has its own implementation:

@deprecation.deprecated_args(
    None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
                     axis=None,
                     keepdims=None,
                     name=None,
                     reduction_indices=None,
                     keep_dims=None):
  """Computes log(sum(exp(elements across dimensions of a tensor))).

  Reduces `input_tensor` along the dimensions given in `axis`.
  Unless `keepdims` is true, the rank of the tensor is reduced by 1 for each
  entry in `axis`. If `keepdims` is true, the reduced dimensions
  are retained with length 1.

  If `axis` has no entries, all dimensions are reduced, and a
  tensor with a single element is returned.

  This function is more numerically stable than log(sum(exp(input))). It avoids
  overflows caused by taking the exp of large inputs and underflows caused by
  taking the log of small inputs.

  Args:
    input_tensor: The tensor to reduce. Should have numeric type.
    axis: The dimensions to reduce. If `None` (the default),
      reduces all dimensions. Must be in the range
      `[-rank(input_tensor), rank(input_tensor))`.
    keepdims: If true, retains reduced dimensions with length 1.
    name: A name for the operation (optional).
    reduction_indices: The old (deprecated) name for axis.
    keep_dims: Deprecated alias for `keepdims`.

  Returns:
    The reduced tensor.
  """
  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
                                                    "keep_dims", keep_dims)
  if keepdims is None:
    keepdims = False
  with ops.name_scope(name, "ReduceLogSumExp", [input_tensor]) as name:
    raw_max = reduce_max(
        input_tensor,
        axis=axis,
        reduction_indices=reduction_indices,
        keepdims=True)
    my_max = array_ops.stop_gradient(
        array_ops.where(
            gen_math_ops.is_finite(raw_max), raw_max,
            array_ops.zeros_like(raw_max)))
    result = gen_math_ops.log(
        reduce_sum(
            gen_math_ops.exp(input_tensor - my_max),
            axis,
            keepdims=True,
            reduction_indices=reduction_indices)) + my_max
    if not keepdims:
      if isinstance(axis, int):
        axis = [axis]
      result = array_ops.squeeze(result, axis)
    return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)

This function computes the reduction over an axis: it first subtracts the maximum over that axis, then applies exp, reduce_sum (over the same axis) and log, and finally adds back the max values found in the first step. Infinite max values need special treatment: TensorFlow's implementation replaces them with zero, whereas the current implementation of reduce_logsum in utils/math.py defines the result to be -inf as soon as the maximum of a row is -inf.
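
For reference, both implementations rely on the same max-subtraction identity, which keeps every exponent at or below zero and thereby avoids overflow:

\log \sum_i e^{x_i}
    = \log\Big(e^{m} \sum_i e^{x_i - m}\Big)
    = m + \log \sum_i e^{x_i - m},
\qquad m = \max_i x_i

When all x_i are -inf the identity degenerates (m = -inf), which is exactly the corner case the two implementations handle differently.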

Proposal

We can implement a custom Op that performs all of the above in a single OpKernel.
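
On the Python side, the compiled kernel could then be loaded and wrapped as a drop-in replacement for reduce_log_sum_v2. The sketch below is a minimal, hypothetical interface (the library name libspn_ops.so and the op name reduce_logsum are placeholders, not a final design):

import tensorflow as tf

# Hypothetical: the custom OpKernel compiled into a shared library
_ops_module = tf.load_op_library('./libspn_ops.so')

def reduce_logsum(log_input, name=None):
    """Reduce the last axis of `log_input` in log-space with a single Op,
    replacing the chain of max/sub/exp/sum/log/where Ops used so far."""
    with tf.name_scope(name, "reduce_logsum", [log_input]):
        return _ops_module.reduce_logsum(log_input)

A gradient for the new Op would also have to be registered separately (e.g. via tf.RegisterGradient) so that learning keeps working.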

Performance comparison

The comparison below was performed on an i7-4200HQ / GTX 960 system. It shows that the TF graph size for an MNIST training task with MARGINAL-LOG inference is reduced by 13 percent, while the rest run time decreases by about 7 percent.

#-----------------------
InferenceType: MARGINAL-LOG
#-----------------------
       op  multi_nodes  SPN_size  TF_size  mem_used     name  setup_time  weights_init_time  first_run_time  rest_run_time  test_accuracy
 mnist_01         True      1400    59636  507.5873  current    38564.98            3607.40         5468.59        1811.08        99.5272
 mnist_01         True      1400    51875  507.5873   custom    33463.82            3373.99         4433.15        1692.15        99.6690

The results on an E5-1650v3 system are:

+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+----------------------------+
| op_name   | on_gpu   | multi_nodes   |   spn_size |   tf_size |   memory_used | input_dist   |   setup_time |   weights_init_time |   first_run_time |   rest_run_time |   test_accuracy | config                     |
|-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+----------------------------|
| mnist_01  | True     | True          |       1400 |     59636 |     500634624 | MIXTURE      |      30.6658 |             3.07669 |          4037.92 |         1248.58 |        0.996217 | custom_reduce_logsum=False |
| mnist_01  | True     | True          |       1400 |     51875 |     505362176 | MIXTURE      |      24.3156 |             2.78745 |          3117.8  |         1114.43 |        0.99669  | custom_reduce_logsum=True  |
+-----------+----------+---------------+------------+-----------+---------------+--------------+--------------+---------------------+------------------+-----------------+-----------------+----------------------------+

Decision
