Skip to content

Commit

Permalink
RF custom_inv_norm_factor fix vector, device
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Oct 19, 2023
1 parent 84a3ac9 commit 4216420
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 6 deletions.
8 changes: 7 additions & 1 deletion returnn/frontend/run_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,11 @@ def get_mean_loss(self) -> Tensor:
return self._mean_loss_cached
if self.custom_inv_norm_factor:
loss = self.get_summed_loss()
loss /= rf.cast(self.custom_inv_norm_factor, dtype=loss.dtype)
inv_norm = rf.reduce_sum(self.custom_inv_norm_factor, axis=self.custom_inv_norm_factor.dims)
inv_norm = rf.cast(inv_norm, loss.dtype)
inv_norm = rf.reciprocal(inv_norm)
inv_norm = rf.copy_to_device(inv_norm, loss.device)
loss *= inv_norm
return loss
if not self.loss.dims:
return self.loss
Expand All @@ -380,6 +384,8 @@ def get_inv_norm_factor(self) -> Union[int, Tensor]:
:return: inverse norm factor (scalar)
"""
if self.custom_inv_norm_factor:
if self.custom_inv_norm_factor.dims:
return rf.reduce_sum(self.custom_inv_norm_factor, axis=self.custom_inv_norm_factor.dims)
return self.custom_inv_norm_factor
return self.loss.num_elements()

Expand Down
57 changes: 53 additions & 4 deletions tests/rf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run_model(

print("** run with PyTorch backend")
with rft.TorchBackend.random_journal_record() as random_journal:
out_pt = run_model_torch(extern_data, get_model, forward_step)
out_pt = _run_model_torch(extern_data, get_model, forward_step)
_pad_mask_zeros(out_pt)
# get the values now because dims might get overwritten
out_pt_raw = out_pt.as_raw_tensor_dict(include_const_sizes=True)
Expand All @@ -60,7 +60,7 @@ def run_model(

print("** run with TensorFlow-net-dict backend")
with rfl.ReturnnLayersBackend.random_journal_replay(random_journal):
out_tf = run_model_net_dict_tf(extern_data, get_model, forward_step)
out_tf = _run_model_net_dict_tf(extern_data, get_model, forward_step)
_pad_mask_zeros(out_tf)
out_tf_raw = out_tf.as_raw_tensor_dict(include_const_sizes=True)

Expand Down Expand Up @@ -93,7 +93,7 @@ def run_model(
return out_pt


def run_model_torch(extern_data: TensorDict, get_model: rf.GetModelFunc, forward_step: rf.StepFunc) -> TensorDict:
def _run_model_torch(extern_data: TensorDict, get_model: rf.GetModelFunc, forward_step: rf.StepFunc) -> TensorDict:
"""run"""
extern_data_raw = extern_data.as_raw_tensor_dict(expected_value_type=numpy.ndarray)
rf.select_backend_torch()
Expand All @@ -115,7 +115,56 @@ def run_model_torch(extern_data: TensorDict, get_model: rf.GetModelFunc, forward
return outputs


def run_model_net_dict_tf(extern_data: TensorDict, get_model: rf.GetModelFunc, forward_step: rf.StepFunc) -> TensorDict:
def run_model_torch_train(
    extern_data: TensorDict,
    get_model: rf.GetModelFunc,
    train_step: rf.StepFunc,
    *,
    dyn_dim_max_sizes: Optional[Dict[Dim, int]] = None,
    dyn_dim_min_sizes: Optional[Dict[Dim, int]] = None,
) -> float:
    """
    Run a single training step with the PyTorch backend and return the total loss.

    Fills ``extern_data`` with random content, builds the model via ``get_model``,
    executes ``train_step`` once inside a train-flag run context, backprops through
    the total loss, and prints per-loss summed/mean values and inverse norm factors.

    :param extern_data: template tensors; content is reset and filled randomly in place
    :param get_model: factory called with ``epoch=1, step=0`` to construct the model
    :param train_step: step function that marks losses on the run context
    :param dyn_dim_max_sizes: upper bounds for random dynamic dim sizes (passed through)
    :param dyn_dim_min_sizes: lower bounds for random dynamic dim sizes (passed through)
    :return: scalar total loss value (Python float)
    """
    rf.select_backend_torch()
    rf.set_random_seed(42)  # deterministic random inputs across runs

    # Fill extern_data in place with random NumPy content, then convert to torch tensors.
    extern_data.reset_content()
    tensor_dict_fill_random_numpy_(
        extern_data, dyn_dim_max_sizes=dyn_dim_max_sizes, dyn_dim_min_sizes=dyn_dim_min_sizes
    )
    tensor_dict_numpy_to_torch_(extern_data)

    # We want to be able to calculate gradients for testing,
    # so we need to set requires_grad=True.
    for v in extern_data.data.values():
        v: Tensor
        v.raw_tensor.requires_grad = True

    model = get_model(epoch=1, step=0)
    rf.init_train_step_run_ctx(train_flag=True, step=0)
    train_step(model=model, extern_data=extern_data)
    total_loss = rf.get_run_ctx().total_loss()
    # total_loss must be a float scalar (no dims) so backward() is well-defined.
    assert isinstance(total_loss, Tensor) and not total_loss.dims and total_loss.raw_tensor.dtype.is_floating_point
    total_loss_v = total_loss.raw_tensor.detach().numpy().item()
    print("total loss (for backprop):", total_loss_v)

    total_loss.raw_tensor.backward()  # test backprop

    # Report each registered loss in summed and mean form, plus its inverse norm factor.
    for k, loss in rf.get_run_ctx().losses.items():
        loss_v = loss.get_summed_loss().raw_tensor.detach().cpu().numpy().item()
        print(f"loss (summed) {k!r}: {loss_v}")
        loss_v = loss.get_mean_loss().raw_tensor.detach().cpu().numpy().item()
        print(f"loss (mean) {k!r}: {loss_v}")
        inv_norm_factor = loss.get_inv_norm_factor()
        if isinstance(inv_norm_factor, Tensor):
            # May still carry dims; reduce to a scalar for printing.
            inv_norm_factor = inv_norm_factor.raw_tensor.detach().sum().cpu().numpy().item()
        print(f"inv_norm_factor {k!r}: {inv_norm_factor}")

    return total_loss_v


def _run_model_net_dict_tf(
extern_data: TensorDict, get_model: rf.GetModelFunc, forward_step: rf.StepFunc
) -> TensorDict:
"""run"""
extern_data_raw = extern_data.as_raw_tensor_dict(expected_value_type=numpy.ndarray)
extern_data.reset_content()
Expand Down
27 changes: 26 additions & 1 deletion tests/test_rf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import _setup_test_env # noqa
import returnn.frontend as rf
from returnn.tensor import Tensor, Dim, TensorDict, batch_dim
from rf_utils import run_model
from rf_utils import run_model, run_model_torch_train


# Keep test_linear_direct and test_linear first here to have some very canonical examples.
Expand Down Expand Up @@ -320,3 +320,28 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):

out = run_model(extern_data, lambda *, epoch, step: _Net(), _forward_step)
assert out["a"].raw_tensor == 2 and out["b"].raw_tensor == 5 and out["c"].raw_tensor == 7


def test_loss_normalized():
    """Exercise summed/mean loss handling incl. a vector custom_inv_norm_factor."""
    in_dim = Dim(7, name="in")
    time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
    extern_data = TensorDict(
        {
            "data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"),
        }
    )

    # noinspection PyShadowingNames
    def _train_step(*, model: rf.Module, extern_data: TensorDict):
        model  # unused # noqa
        data = extern_data["data"]

        frame_loss = rf.reduce_sum(data, axis=in_dim)  # [B,T]
        frame_loss.mark_as_loss("loss", use_normalized_loss=True)

        # Per-seq loss, normalized by the per-seq lengths (a vector, not a scalar).
        seq_loss = rf.reduce_sum(frame_loss, axis=time_dim)  # [B]
        seq_loss.mark_as_loss(
            "loss_custom_norm", custom_inv_norm_factor=time_dim.get_size_tensor(), use_normalized_loss=True
        )

    run_model_torch_train(extern_data, lambda *, epoch, step: rf.Module(), _train_step)

0 comments on commit 4216420

Please sign in to comment.