Dim dim value precomputed
albertz committed Oct 9, 2023
1 parent f1a678f commit dc1c411
Showing 2 changed files with 69 additions and 49 deletions.
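This commit adds a private scalar cache _dyn_size_max_value to Dim: complete_dyn_size() now derives the maximum dim value alongside dyn_size_ext, get_dim_value_tensor() returns the cached scalar when present instead of reducing over dyn_size_ext again, and reset_eager() clears the cache. A minimal sketch of that caching pattern follows, in plain NumPy and with made-up names (DimSketch and its methods are illustrative stand-ins, not the RETURNN API):

    from typing import Optional
    import numpy as np

    class DimSketch:
        """Toy stand-in for a Dim with per-entry dynamic sizes."""

        def __init__(self, dyn_sizes: np.ndarray):
            self.dyn_sizes = dyn_sizes  # per-entry sequence lengths, like dyn_size_ext
            self._dyn_size_max_value: Optional[int] = None  # cached scalar max

        def reset(self) -> None:
            # Mirrors reset_eager(): cached derived values must be dropped.
            self._dyn_size_max_value = None

        def get_dim_value(self) -> int:
            # Fast path: serve the precomputed scalar if available.
            if self._dyn_size_max_value is not None:
                return self._dyn_size_max_value
            # Slow path: reduce over the per-entry sizes, then cache the result.
            self._dyn_size_max_value = int(np.max(self.dyn_sizes))
            return self._dyn_size_max_value

    dim = DimSketch(np.array([3, 7, 5]))
    print(dim.get_dim_value())  # 7, computed via the reduction
    print(dim.get_dim_value())  # 7 again, now served from the cache
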
116 changes: 67 additions & 49 deletions returnn/tensor/_dim_extra.py
@@ -135,6 +135,7 @@ class _DimMixin:
capacity: Optional[int]
size: Optional[int]
dyn_size_ext: Optional[_t.Tensor]
_dyn_size_max_value: Optional[_t.Tensor] # scalar
_extra: Optional[_DimExtra]

def _handle_extra_kwargs(self: Dim, *, dyn_size: Optional[_t.RawTensorType] = None, **kwargs):
@@ -350,6 +351,7 @@ def reset_eager(self: Dim):
This resets everything related.
This can also include caches.
"""
self._dyn_size_max_value = None
if self.dyn_size_ext:
self.dyn_size_ext.raw_tensor = None
if self._extra:
@@ -1002,26 +1004,26 @@ def complete_dyn_size(self, template_only=False):
if not op:
return

for x in op.inputs:
for x_dim in op.inputs:
if self.batch:
x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
x.complete_dyn_size(template_only=template_only)
x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
x_dim.complete_dyn_size(template_only=template_only)

backend = None
for x in op.inputs:
for x_dim in op.inputs:
if self.batch:
x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
if x.dyn_size_ext and x.dyn_size_ext.raw_tensor is not None:
x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
if x_dim.dyn_size_ext and x_dim.dyn_size_ext.raw_tensor is not None:
# noinspection PyProtectedMember
backend = x.dyn_size_ext._raw_backend
backend = x_dim.dyn_size_ext._raw_backend
break

size_dtype = None
for x in op.inputs:
for x_dim in op.inputs:
if self.batch:
x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
if x.dyn_size_ext:
size_dtype = x.dyn_size_ext.dtype
x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
if x_dim.dyn_size_ext:
size_dtype = x_dim.dyn_size_ext.dtype
break
if not size_dtype:
size_dtype = _t.Tensor.size_dtype
@@ -1090,6 +1092,16 @@ def _bin_op(a, b):
if isinstance(b, _t.Tensor):
return b
raise Exception(f"Dim complete_dyn_size: bin_op: expect to get one Tensor, got {a} and {b}")
if tf:
assert isinstance(a, _t.Tensor)
if isinstance(b, _t.Tensor):
res = _t.Tensor.get_common_data([a, b], allow_broadcast_all_sources=True)
a = a.copy_compatible_to_dims(res.dims) if a.dims else a
b = b.copy_compatible_to_dims(res.dims) if b.dims else b
else:
res = a.copy_template()
res.raw_tensor = _bin_op_tf(a.raw_tensor, b.raw_tensor if isinstance(b, _t.Tensor) else b)
return res
if kind == "add":
return _relu(rf.combine_bc(a, "add", b))
elif kind == "sub":
@@ -1113,16 +1125,18 @@ def _relu(a):

y_name = self.description + ":seq-length"
y: Optional[_t.Tensor] = None # resulting dyn size
y_max_value: Optional[_t.Tensor] = None # resulting dyn size max value
inputs = list(op.inputs)
assert inputs
while inputs:
x = inputs.pop(0)
if not x.is_dynamic(): # static
assert x.dimension is not None
x_dim: Dim = inputs.pop(0)
if not x_dim.is_dynamic(): # static
assert x_dim.dimension is not None
if y is None:
if not template_only and backend and not tf:
with rf.set_default_device_ctx(None):
y = backend.convert_to_tensor(x.dimension, dims=[], dtype=size_dtype, name=y_name)
y = backend.convert_to_tensor(
x_dim.dimension, dims=[], dtype=size_dtype, name=y_name, device="cpu"
)
else:
y = _t.Tensor(
name=y_name,
@@ -1131,31 +1145,24 @@ def _relu(a):
)
if not template_only and tf:
with tf.control_dependencies(None): # this will reset the context
y.raw_tensor = tf.constant(x.dimension)
y.raw_tensor = tf.constant(x_dim.dimension)
y_max_value = y.copy()
continue
if tf:
y.placeholder = _bin_op_tf(y.placeholder, x.dimension)
else:
y = _bin_op(y, x.dimension)
y = _bin_op(y, x_dim.dimension)
y_max_value = _bin_op(y_max_value, x_dim.dimension)
continue
if self.batch:
x = x.get_for_batch_ctx(self.batch, self.control_flow_ctx)
x.complete_dyn_size(template_only=template_only)
if not x.dyn_size_ext:
x_dim = x_dim.get_for_batch_ctx(self.batch, self.control_flow_ctx)
x_dim.complete_dyn_size(template_only=template_only)
if not x_dim.dyn_size_ext:
return
x = x.dyn_size_ext
if y is None:
y = x.copy(name=y_name)
y = x_dim.dyn_size_ext.copy(name=y_name)
y_max_value = x_dim.get_dim_value_tensor()
continue
if tf:
common = _t.Tensor.get_common_data([x, y], allow_broadcast_all_sources=True)
x_ = x.copy_compatible_to_dims(common.dims) if x.dims else x
y_ = y.copy_compatible_to_dims(common.dims) if y.dims else y
y = common
y.placeholder = _bin_op_tf(y_.placeholder, x_.placeholder)
else:
y = _bin_op(y, x)
assert y, f"op {op}?"
y = _bin_op(y, x_dim.dyn_size_ext)
y_max_value = _bin_op(y_max_value, x_dim.get_dim_value_tensor())
assert y and y_max_value, f"op {op}?"
if self.dyn_size_ext:
assert self.dyn_size_ext.dim_tags == y.dim_tags
if y.batch:
@@ -1164,6 +1171,7 @@ def _relu(a):
else:
self.batch = y.batch
self.dyn_size_ext = y
self._dyn_size_max_value = y_max_value
if tf and y.placeholder is not None:
self.set_tag_on_size_tensor(y.placeholder)

@@ -1797,16 +1805,8 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:

if self.dimension is not None:
return self.dimension
if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None: # fast path
if self.dyn_size_ext.batch_ndim > 0:
return rf.reduce_max(
self.dyn_size_ext,
axis=self.dyn_size_ext.dim_tags,
# Masking is not always possible here, e.g.
# self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
use_mask=False,
)
return self.dyn_size_ext
if self._dyn_size_max_value is not None: # fast path, precomputed
return self._dyn_size_max_value
if self.is_batch_dim():
res = None
if self._extra and self._extra.src_data:
@@ -1816,7 +1816,9 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
if isinstance(res, int):
return res
if res is not None:
return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
self._dyn_size_max_value = res
return res
if (
self._extra
and self._extra.src_data is not None
@@ -1826,12 +1828,28 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
res = self._extra.src_data.get_dim(self._extra.src_axis)
if isinstance(res, int):
return res
return _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
return _t.Tensor(
f"{self._extra.src_data}:shape[{self._extra.src_axis}]",
dims=(),
dtype=rf.get_default_array_index_dtype(),
raw_tensor=res,
)
self.complete_dyn_size()
if self._dyn_size_max_value is not None:
return self._dyn_size_max_value
if self.dyn_size_ext and self.dyn_size_ext.placeholder is not None:
if self.dyn_size_ext.batch_ndim > 0:
return rf.reduce_max(self.dyn_size_ext, axis=self.dyn_size_ext.dim_tags)
return self.dyn_size_ext
res = rf.reduce_max(
self.dyn_size_ext,
axis=self.dyn_size_ext.dim_tags,
# Masking is not always possible here, e.g.
# self = Dim{'self-att-keys'['time:var:extern_data:classes'[B]]}.
use_mask=False,
)
else:
res = self.dyn_size_ext
self._dyn_size_max_value = res
return res
raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self)

def axis_split_info(self):
2 changes: 2 additions & 0 deletions returnn/tensor/dim.py
@@ -51,6 +51,7 @@ class Dim(_DimMixin):
capacity: Optional[int] # shape[axis] in the raw tensor (might need power-of-two or static shape), None if dynamic
size: Optional[int] # shape[axis] in the represented tensor if static, None if dynamic, then dyn_size_ext
dyn_size_ext: Optional[_t.Tensor]
_dyn_size_max_value: Optional[_t.Tensor] # scalar
_extra: Optional[_DimExtra]

def __init__(
@@ -84,6 +85,7 @@ def __init__(
if not name and not description and self.dyn_size_ext:
name = self.dyn_size_ext.name
self.name = name or description
self._dyn_size_max_value = None
self._extra = None

if kwargs:

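For derived dims (those with an op and inputs), complete_dyn_size() now builds y_max_value by applying the same binary op to the inputs' dim values (via get_dim_value_tensor()) instead of reducing over the finished dyn_size_ext, as the old get_dim_value_tensor() fast path did. A rough NumPy illustration of that propagation for an "add" op follows; note it yields an upper bound rather than the exact per-entry maximum, and whether that distinction matters depends on how the dim value is consumed, so this only shows the mechanics, not a claim about intended semantics:

    import numpy as np

    a_sizes = np.array([3, 7, 5])  # dyn sizes of input dim a
    b_sizes = np.array([2, 1, 4])  # dyn sizes of input dim b

    # dyn_size_ext of the derived dim a + b: element-wise combination of the inputs.
    sum_sizes = a_sizes + b_sizes  # [5, 8, 9]

    # Precomputed dim value of a + b: combine the inputs' cached maxima directly.
    sum_max_value = int(np.max(a_sizes)) + int(np.max(b_sizes))  # 7 + 4 = 11

    # The exact maximum of the combined sizes can be smaller than that bound.
    print(sum_sizes, sum_max_value, int(np.max(sum_sizes)))  # [5 8 9] 11 9
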