ConvLayer: add pad_seq_len_to_power (#1468)
vieting authored Dec 1, 2023
1 parent 8a7d673 commit c4d36d0
Showing 1 changed file with 31 additions and 0 deletions.
returnn/tf/layers/basic.py (31 additions, 0 deletions)
@@ -6120,6 +6120,7 @@ def __init__(
         filter_perm=None,
         bias=None,
         use_time_mask=False,
+        pad_seq_len_to_power=None,
         **kwargs,
     ):
         """
@@ -6158,6 +6159,10 @@ def __init__(
         :param dict[str,str]|None filter_perm: transposes the filter (input filter as layer)
         :param LayerBase|None bias: if given, will not create an own parameter, but use this as the bias
         :param bool use_time_mask:
+        :param Optional[float] pad_seq_len_to_power: pad the sequence length up to the next power of the given
+            number, to reduce the number of different sequence lengths.
+            See https://github.com/rwth-i6/returnn/issues/1450 and
+            https://github.com/tensorflow/tensorflow/issues/62441.
         """
         from returnn.util import BehaviorVersion

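As background (not part of the commit): with base p, a sequence of length n is padded up to p ** ceil(log_p(n)). A minimal plain-Python sketch of that computation, using a hypothetical helper name:

import math

def padded_length(seq_len, base):
    """Smallest power of `base` that is >= seq_len."""
    return int(math.ceil(base ** math.ceil(math.log(seq_len) / math.log(base))))

assert padded_length(9, 2.0) == 16   # lengths 9..15 all map to 16,
assert padded_length(17, 2.0) == 32  # so far fewer distinct lengths occur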
@@ -6301,6 +6306,22 @@ def __init__(
         else:
             x = input_data.placeholder

+        if pad_seq_len_to_power is None:
+            pad_seq_len_to_power = self.network.get_config().float("conv_pad_seq_len_to_power", None)
+        if pad_seq_len_to_power is not None:
+            pad_seq_len_to_power = float(pad_seq_len_to_power)
+            padding_for_power = []
+            for ax in range(input_data.batch_ndim):
+                if input_data.is_axis_dynamic(ax):
+                    seq_len = tf.cast(tf.shape(x)[ax], tf.float32)
+                    padded_len = tf.math.ceil(
+                        pad_seq_len_to_power ** (tf.math.ceil(tf.math.log(seq_len) / tf.math.log(pad_seq_len_to_power)))
+                    )
+                    padding_for_power.append((0, tf.cast(padded_len - seq_len, tf.int32)))  # tf.pad needs int paddings
+                else:
+                    padding_for_power.append((0, 0))
+            x = tf.pad(x, padding_for_power)
+
         extended_batch_shape = None
         if num_batch_dims > 1:
             x_shape = tf.shape(x)
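For reference, the layer falls back to the global "conv_pad_seq_len_to_power" config key read above, so the behavior can also be enabled for all ConvLayers at once. A sketch of the config entry (RETURNN configs are Python files):

conv_pad_seq_len_to_power = 2.0  # pad the dynamic axes of all ConvLayer inputs to the next power of 2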
@@ -6361,6 +6382,16 @@ def __init__(
         )
         if num_batch_dims > 1:
             y = tf.reshape(y, tf.concat([extended_batch_shape, tf.shape(y)[1:]], axis=0))
+
+        if pad_seq_len_to_power is not None:
+            slice_size = []
+            for ax in range(self.output.batch_ndim):
+                if self.output.is_axis_dynamic(ax):
+                    slice_size.append(tf.reduce_max(self.output.get_dynamic_size(ax)))
+                else:
+                    slice_size.append(tf.shape(y)[ax])
+            y = tf.slice(y, begin=[0] * len(slice_size), size=slice_size)
+
         # y shape is [batch] + dynamic_dims + [n_out].
         if with_bias is NotSpecified:
             if bias or BehaviorVersion.get() >= 10:
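The slicing above undoes the earlier padding: the conv output still has the padded length, so it is cropped back to the maximum real (dynamic) size on each axis before later layers see it. A standalone TensorFlow toy version of the overall pad/convolve/crop pattern (not RETURNN code; shapes are illustrative):

import tensorflow as tf

x = tf.random.normal([3, 100, 40])  # [batch, time, feature], max real length 100
x_padded = tf.pad(x, [(0, 0), (0, 128 - 100), (0, 0)])  # 128 is the next power of 2 above 100
y = tf.keras.layers.Conv1D(filters=32, kernel_size=3, padding="same")(x_padded)  # [3, 128, 32]
y = tf.slice(y, begin=[0, 0, 0], size=[3, 100, 32])  # crop back to the real max length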

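Per-layer usage in a network dict might look as follows (a sketch; the other options are the usual ConvLayer settings, and the values here are illustrative):

network = {
    "conv": {
        "class": "conv",
        "from": "data",
        "filter_size": (3,),
        "padding": "same",
        "n_out": 32,
        "pad_seq_len_to_power": 2.0,  # pad the time axis to the next power of 2
    },
}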