From dc1c4114d5a531ff077e5475307fe4892d1de29d Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Mon, 9 Oct 2023 16:06:48 +0200
Subject: [PATCH] Dim dim value precomputed

---
 returnn/tensor/_dim_extra.py | 116 ++++++++++++++++++++---------------
 returnn/tensor/dim.py        |   2 +
 2 files changed, 69 insertions(+), 49 deletions(-)

diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py
index 744fb8f8d9..1fcc0fd805 100644
--- a/returnn/tensor/_dim_extra.py
+++ b/returnn/tensor/_dim_extra.py
@@ -135,6 +135,7 @@ class _DimMixin:
     capacity: Optional[int]
     size: Optional[int]
     dyn_size_ext: Optional[_t.Tensor]
+    _dyn_size_max_value: Optional[_t.Tensor]  # scalar
     _extra: Optional[_DimExtra]
 
     def _handle_extra_kwargs(self: Dim, *, dyn_size: Optional[_t.RawTensorType] = None, **kwargs):
@@ -350,6 +351,7 @@ def reset_eager(self: Dim):
         This resets everything related.
         This can also include caches.
         """
+        self._dyn_size_max_value = None
         if self.dyn_size_ext:
             self.dyn_size_ext.raw_tensor = None
         if self._extra:
@@ -1002,26 +1004,26 @@ def complete_dyn_size(self, template_only=False):
         if not op:
             return
-        for x in op.inputs:
+        for x_dim in op.inputs:
             if self.batch:
-                x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
-            x.complete_dyn_size(template_only=template_only)
+                x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
+            x_dim.complete_dyn_size(template_only=template_only)
         backend = None
-        for x in op.inputs:
+        for x_dim in op.inputs:
             if self.batch:
-                x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
-            if x.dyn_size_ext and x.dyn_size_ext.raw_tensor is not None:
+                x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
+            if x_dim.dyn_size_ext and x_dim.dyn_size_ext.raw_tensor is not None:
                 # noinspection PyProtectedMember
-                backend = x.dyn_size_ext._raw_backend
+                backend = x_dim.dyn_size_ext._raw_backend
                 break
         size_dtype = None
-        for x in op.inputs:
+        for x_dim in op.inputs:
             if self.batch:
-                x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
-            if x.dyn_size_ext:
-                size_dtype = x.dyn_size_ext.dtype
+                x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
+            if x_dim.dyn_size_ext:
+                size_dtype = x_dim.dyn_size_ext.dtype
                 break
         if not size_dtype:
             size_dtype = _t.Tensor.size_dtype
@@ -1090,6 +1092,16 @@ def _bin_op(a, b):
                 if isinstance(b, _t.Tensor):
                     return b
                 raise Exception(f"Dim complete_dyn_size: bin_op: expect to get one Tensor, got {a} and {b}")
+            if tf:
+                assert isinstance(a, _t.Tensor)
+                if isinstance(b, _t.Tensor):
+                    res = _t.Tensor.get_common_data([a, b], allow_broadcast_all_sources=True)
+                    a = a.copy_compatible_to_dims(res.dims) if a.dims else a
+                    b = b.copy_compatible_to_dims(res.dims) if b.dims else b
+                else:
+                    res = a.copy_template()
+                res.raw_tensor = _bin_op_tf(a.raw_tensor, b.raw_tensor if isinstance(b, _t.Tensor) else b)
+                return res
             if kind == "add":
                 return _relu(rf.combine_bc(a, "add", b))
             elif kind == "sub":
@@ -1113,16 +1125,18 @@ def _relu(a):
         y_name = self.description + ":seq-length"
         y: Optional[_t.Tensor] = None  # resulting dyn size
+        y_max_value: Optional[_t.Tensor] = None  # resulting dyn size max value
         inputs = list(op.inputs)
         assert inputs
         while inputs:
-            x = inputs.pop(0)
-            if not x.is_dynamic():  # static
-                assert x.dimension is not None
+            x_dim: Dim = inputs.pop(0)
+            if not x_dim.is_dynamic():  # static
+                assert x_dim.dimension is not None
                 if y is None:
                     if not template_only and backend and not tf:
-                        with rf.set_default_device_ctx(None):
-                            y = backend.convert_to_tensor(x.dimension, dims=[], dtype=size_dtype, name=y_name)
+                        y = backend.convert_to_tensor(
+                            x_dim.dimension, dims=[], dtype=size_dtype, name=y_name, device="cpu"
+                        )
                     else:
                         y = _t.Tensor(
                             name=y_name,
@@ -1131,31 +1145,24 @@ def _relu(a):
                         )
                     if not template_only and tf:
                         with tf.control_dependencies(None):  # this will reset the context
-                            y.raw_tensor = tf.constant(x.dimension)
+                            y.raw_tensor = tf.constant(x_dim.dimension)
+                    y_max_value = y.copy()
                     continue
-                if tf:
-                    y.placeholder = _bin_op_tf(y.placeholder, x.dimension)
-                else:
-                    y = _bin_op(y, x.dimension)
+                y = _bin_op(y, x_dim.dimension)
+                y_max_value = _bin_op(y_max_value, x_dim.dimension)
                 continue
             if self.batch:
-                x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
-            x.complete_dyn_size(template_only=template_only)
-            if not x.dyn_size_ext:
+                x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
+            x_dim.complete_dyn_size(template_only=template_only)
+            if not x_dim.dyn_size_ext:
                 return
-            x = x.dyn_size_ext
            if y is None:
-                y = x.copy(name=y_name)
+                y = x_dim.dyn_size_ext.copy(name=y_name)
+                y_max_value = x_dim.get_dim_value_tensor()
                 continue
-            if tf:
-                common = _t.Tensor.get_common_data([x, y], allow_broadcast_all_sources=True)
-                x_ = x.copy_compatible_to_dims(common.dims) if x.dims else x
-                y_ = y.copy_compatible_to_dims(common.dims) if y.dims else y
-                y = common
-                y.placeholder = _bin_op_tf(y_.placeholder, x_.placeholder)
-            else:
-                y = _bin_op(y, x)
-        assert y, f"op {op}?"
+            y = _bin_op(y, x_dim.dyn_size_ext)
+            y_max_value = _bin_op(y_max_value, x_dim.get_dim_value_tensor())
+        assert y and y_max_value, f"op {op}?"
         if self.dyn_size_ext:
             assert self.dyn_size_ext.dim_tags == y.dim_tags
         if y.batch:
@@ -1164,6 +1171,7 @@ def _relu(a):
             else:
                 self.batch = y.batch
         self.dyn_size_ext = y
+        self._dyn_size_max_value = y_max_value
         if tf and y.placeholder is not None:
             self.set_tag_on_size_tensor(y.placeholder)
@@ -1797,16 +1805,8 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
         if self.dimension is not None:
             return self.dimension
-        if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:  # fast path
-            if self.dyn_size_ext.batch_ndim > 0:
-                return rf.reduce_max(
-                    self.dyn_size_ext,
-                    axis=self.dyn_size_ext.dim_tags,
-                    # Masking is not always possible here, e.g.
-                    # self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
-                    use_mask=False,
-                )
-            return self.dyn_size_ext
+        if self._dyn_size_max_value is not None:  # fast path, precomputed
+            return self._dyn_size_max_value
         if self.is_batch_dim():
             res = None
             if self._extra and self._extra.src_data:
@@ -1816,7 +1816,9 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
             if isinstance(res, int):
                 return res
             if res is not None:
-                return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
+                res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
+                self._dyn_size_max_value = res
+                return res
         if (
             self._extra
             and self._extra.src_data is not None
@@ -1826,12 +1828,28 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
             res = self._extra.src_data.get_dim(self._extra.src_axis)
             if isinstance(res, int):
                 return res
-            return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
+            return _t.Tensor(
+                f"{self._extra.src_data}:shape[{self._extra.src_axis}]",
+                dims=(),
+                dtype=rf.get_default_array_index_dtype(),
+                raw_tensor=res,
+            )
         self.complete_dyn_size()
+        if self._dyn_size_max_value is not None:
+            return self._dyn_size_max_value
         if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
             if self.dyn_size_ext.batch_ndim > 0:
-                return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags)
-            return self.dyn_size_ext
+                res = rf.reduce_max(
+                    self.dyn_size_ext,
+                    axis=self.dyn_size_ext.dim_tags,
+                    # Masking is not always possible here, e.g.
+                    # self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
+                    use_mask=False,
+                )
+            else:
+                res = self.dyn_size_ext
+            self._dyn_size_max_value = res
+            return res
         raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self)
 
     def axis_split_info(self):
diff --git a/returnn/tensor/dim.py b/returnn/tensor/dim.py
index 8bc612f1c4..24dc1c35dd 100644
--- a/returnn/tensor/dim.py
+++ b/returnn/tensor/dim.py
@@ -51,6 +51,7 @@ class Dim(_DimMixin):
     capacity: Optional[int]  # shape[axis] in the raw tensor (might need power-of-two or static shape), None if dynamic
     size: Optional[int]  # shape[axis] in the represented tensor if static, None if dynamic, then dyn_size_ext
     dyn_size_ext: Optional[_t.Tensor]
+    _dyn_size_max_value: Optional[_t.Tensor]  # scalar
     _extra: Optional[_DimExtra]
 
     def __init__(
@@ -84,6 +85,7 @@ def __init__(
         if not name and not description and self.dyn_size_ext:
             name = self.dyn_size_ext.name
         self.name = name or description
+        self._dyn_size_max_value = None
        self._extra = None
 
         if kwargs:
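Note, illustrative only and not part of the patch: the change caches the scalar max
of dyn_size_ext in Dim._dyn_size_max_value, so repeated Dim.get_dim_value_tensor()
calls return the cached scalar instead of recomputing the reduce_max, and
reset_eager() clears the cache again. A minimal sketch of that intended behavior,
assuming the PyTorch backend and that Dim accepts a dyn-size Tensor positionally
(the example names and values are hypothetical, not taken from this patch):

    import torch
    from returnn.tensor import Tensor, Dim, batch_dim

    # Per-sequence lengths for a dynamic time dim (hypothetical values).
    sizes = Tensor(
        "sizes", dims=[batch_dim], dtype="int32",
        raw_tensor=torch.tensor([3, 5, 4], dtype=torch.int32),
    )
    time_dim = Dim(sizes, name="time")

    # First call reduces dyn_size_ext to its max and caches it in _dyn_size_max_value.
    v1 = time_dim.get_dim_value_tensor()
    # Second call takes the precomputed fast path and returns the cached scalar Tensor.
    v2 = time_dim.get_dim_value_tensor()
    assert v1 is v2

    # reset_eager() clears the cached max together with the raw sizes.
    time_dim.reset_eager()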