From 6df0b866c377d9583a2698aaa30affa37834ed0f Mon Sep 17 00:00:00 2001 From: SimBe195 <37951951+SimBe195@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:39:34 +0100 Subject: [PATCH] Add **kwargs to LayerNormNC constructor (#61) --- i6_models/parts/conformer/norm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/i6_models/parts/conformer/norm.py b/i6_models/parts/conformer/norm.py index d0401e55..c46155fe 100644 --- a/i6_models/parts/conformer/norm.py +++ b/i6_models/parts/conformer/norm.py @@ -9,11 +9,11 @@ class LayerNormNC(nn.LayerNorm): see here: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html """ - def __init__(self, channels: int): + def __init__(self, channels: int, **kwargs): """ :param channels: number of channels for normalization """ - super().__init__(channels) + super().__init__(channels, **kwargs) def forward(self, tensor: torch.Tensor) -> torch.Tensor: """