From c4d36d06f6465e82a50d400d114259e07b8b0709 Mon Sep 17 00:00:00 2001
From: vieting <45091115+vieting@users.noreply.github.com>
Date: Fri, 1 Dec 2023 16:49:11 +0100
Subject: [PATCH] ConvLayer: add pad_seq_len_to_power (#1468)

---
 returnn/tf/layers/basic.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index ce1c6f520b..0d9adee741 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -6120,6 +6120,7 @@ def __init__(
         filter_perm=None,
         bias=None,
         use_time_mask=False,
+        pad_seq_len_to_power=None,
         **kwargs,
     ):
         """
@@ -6158,6 +6159,10 @@ def __init__(
         :param dict[str,str]|None filter_perm: transposes the filter (input filter as layer)
         :param LayerBase|None bias: if given, will not create an own parameter, but use this as the bias
         :param bool use_time_mask:
+        :param Optional[float] pad_seq_len_to_power: pad sequence length to power of given number
+            to reduce number of different sequence lengths.
+            See https://github.com/rwth-i6/returnn/issues/1450 and
+            https://github.com/tensorflow/tensorflow/issues/62441.
         """
         from returnn.util import BehaviorVersion
 
@@ -6301,6 +6306,22 @@ def __init__(
         else:
             x = input_data.placeholder
 
+        if pad_seq_len_to_power is None:
+            pad_seq_len_to_power = self.network.get_config().float("conv_pad_seq_len_to_power", None)
+        if pad_seq_len_to_power is not None:
+            pad_seq_len_to_power = float(pad_seq_len_to_power)
+            padding_for_power = []
+            for ax in range(input_data.batch_ndim):
+                if input_data.is_axis_dynamic(ax):
+                    seq_len = tf.cast(tf.shape(x)[ax], tf.float32)
+                    padded_len = tf.math.ceil(
+                        pad_seq_len_to_power ** (tf.math.ceil(tf.math.log(seq_len) / tf.math.log(pad_seq_len_to_power)))
+                    )
+                    padding_for_power.append((0, padded_len - seq_len))
+                else:
+                    padding_for_power.append((0, 0))
+            x = tf.pad(x, padding_for_power)
+
         extended_batch_shape = None
         if num_batch_dims > 1:
             x_shape = tf.shape(x)
@@ -6361,6 +6382,16 @@ def __init__(
         )
         if num_batch_dims > 1:
             y = tf.reshape(y, tf.concat([extended_batch_shape, tf.shape(y)[1:]], axis=0))
+
+        if pad_seq_len_to_power is not None:
+            slice_size = []
+            for ax in range(self.output.batch_ndim):
+                if self.output.is_axis_dynamic(ax):
+                    slice_size.append(tf.reduce_max(self.output.get_dynamic_size(ax)))
+                else:
+                    slice_size.append(tf.shape(y)[ax])
+            y = tf.slice(y, begin=[0] * len(slice_size), size=slice_size)
+
         # y shape is [batch] + dynamic_dims + [n_out].
         if with_bias is NotSpecified:
             if bias or BehaviorVersion.get() >= 10:
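
Note (not part of the commit): to illustrate what the new option does, here is a minimal standalone sketch of the padded-length formula used in the first added block above. The sequence length is rounded up to the next power of the given base, so many different raw lengths collapse onto a few padded lengths. The helper name padded_length is made up for illustration.

import math

def padded_length(seq_len, base=2.0):
    """Round seq_len up to the next integer power of `base` (same formula as in the patch)."""
    return int(math.ceil(base ** math.ceil(math.log(seq_len) / math.log(base))))

# With base 2, all lengths from 513 to 1023 map to 1024, so a dataset with many
# distinct sequence lengths produces far fewer distinct conv input shapes.
assert padded_length(300) == 512
assert padded_length(600) == 1024
assert padded_length(1000) == 1024

The padded frames are only temporary: as the second added block shows, the output is sliced back to the maximum dynamic sequence length after the convolution, so downstream layers see the usual shapes.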
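The option can be set per layer via the new pad_seq_len_to_power argument, or globally through the conv_pad_seq_len_to_power config entry that the patch reads as a fallback. A hypothetical RETURNN config fragment follows; the layer name and the surrounding conv options are illustrative and not taken from the patch.

conv_pad_seq_len_to_power = 2  # global default, picked up by ConvLayer when the layer argument is not set

network = {
    "conv0": {
        "class": "conv",
        "from": "data",
        "filter_size": (3,),
        "padding": "same",
        "n_out": 32,
        "pad_seq_len_to_power": 2,  # per-layer setting, takes precedence over the global config entry
    },
}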