Skip to content

Commit

Permalink
RF init Normal
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 18, 2023
1 parent 66dad40 commit d235f7d
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions returnn/frontend/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .. import frontend as rf


__all__ = ["ParamInit", "ParamInitType", "Normal", "VarianceScaling", "Glorot", "He", "HeNormal", "HeUniform"]


class ParamInit:
"""API for param init"""

Expand All @@ -29,6 +32,40 @@ def __call__(
ParamInitType = Union[Tensor, rf.RawTensorTypes, ParamInit]


class Normal(ParamInit):
    """
    Init with normal distribution (mean 0.0), truncated by default.
    """

    def __init__(self, stddev: float, *, truncated: bool = True, dtype: Optional[str] = None):
        """
        :param stddev: standard deviation of the (truncated) normal distribution. Must be > 0.
        :param truncated: whether to use a truncated normal distribution (default: True)
        :param dtype: if not given, uses :func:`rf.get_default_float_dtype`
        """
        # Validate before resolving the default dtype, so an invalid stddev fails fast
        # independently of the RF backend state.
        if stddev <= 0.0:
            raise ValueError(f"Argument `stddev` must be a positive float. Received: {stddev}")
        self.stddev = stddev
        self.truncated = truncated
        if dtype is None:
            dtype = rf.get_default_float_dtype()
        self.dtype = dtype

    def __call__(
        self,
        dims: Sequence[Dim],
        dtype: Optional[str],
        *,
        sparse_dim: Optional[Dim] = None,
        device: Optional[str] = None,
        out: Optional[Tensor] = None,
    ) -> Tensor:
        """
        :param dims: dims of the param to create
        :param dtype: if None, falls back to the dtype given (or defaulted) in ``__init__``
        :param sparse_dim: passed through to :func:`rf.random`
        :param device: passed through to :func:`rf.random`
        :param out: passed through to :func:`rf.random`
        :return: randomly initialized tensor with the given dims
        """
        if dtype is None:
            dtype = self.dtype
        return rf.random(
            distribution="truncated_normal" if self.truncated else "normal",
            static=True,
            dims=dims,
            mean=0.0,
            stddev=self.stddev,
            dtype=dtype,
            sparse_dim=sparse_dim,
            device=device,
            out=out,
        )


class VarianceScaling(ParamInit):
"""
Provides a generalized way for initializing weights.
Expand Down

0 comments on commit d235f7d

Please sign in to comment.