From d235f7d2120e14855ecf0face6c6e6d1eb027ba8 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Sat, 18 Nov 2023 23:46:07 +0000
Subject: [PATCH] RF init Normal

---
 returnn/frontend/init.py | 44 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/returnn/frontend/init.py b/returnn/frontend/init.py
index 81f66b432b..2ffa08f160 100644
--- a/returnn/frontend/init.py
+++ b/returnn/frontend/init.py
@@ -11,6 +11,9 @@
 from .. import frontend as rf
 
 
+__all__ = ["ParamInit", "ParamInitType", "Normal", "VarianceScaling", "Glorot", "He", "HeNormal", "HeUniform"]
+
+
 class ParamInit:
     """API for param init"""
 
@@ -29,6 +32,47 @@ def __call__(
 ParamInitType = Union[Tensor, rf.RawTensorTypes, ParamInit]
 
 
+class Normal(ParamInit):
+    """Init with a zero-mean normal (or truncated normal) distribution with given standard deviation."""
+
+    def __init__(self, stddev: float, *, truncated: bool = True, dtype: Optional[str] = None):
+        """
+        :param stddev: standard deviation of the distribution. must be positive
+        :param truncated: if True (default), sample from a truncated normal distribution
+        :param dtype: defaults to :func:`rf.get_default_float_dtype`
+        """
+        if stddev <= 0.0:
+            raise ValueError(f"Argument `stddev` must be a positive float. Received: {stddev}")
+        self.stddev = stddev
+        self.truncated = truncated
+        if dtype is None:
+            dtype = rf.get_default_float_dtype()
+        self.dtype = dtype
+
+    def __call__(
+        self,
+        dims: Sequence[Dim],
+        dtype: Optional[str],
+        *,
+        sparse_dim: Optional[Dim] = None,
+        device: Optional[str] = None,
+        out: Optional[Tensor] = None,
+    ) -> Tensor:
+        # dtype passed by the caller wins; fall back to the dtype chosen at construction time.
+        if dtype is None:
+            dtype = self.dtype
+        return rf.random(
+            distribution="truncated_normal" if self.truncated else "normal",
+            static=True,
+            dims=dims,
+            mean=0.0,
+            stddev=self.stddev,
+            dtype=dtype,
+            sparse_dim=sparse_dim,
+            device=device,
+            out=out,
+        )
+
+
 class VarianceScaling(ParamInit):
     """
     Provides a generalized way for initializing weights.