From dde2533d2994824acb148a2d06c569ac4d97eefb Mon Sep 17 00:00:00 2001 From: Judyxujj Date: Fri, 10 Nov 2023 15:21:39 +0100 Subject: [PATCH] Update i6_models/parts/frontend/generic_frontend.py Co-authored-by: SimBe195 <37951951+SimBe195@users.noreply.github.com> --- i6_models/parts/frontend/generic_frontend.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/i6_models/parts/frontend/generic_frontend.py b/i6_models/parts/frontend/generic_frontend.py index a50111dd..9b95f6c1 100644 --- a/i6_models/parts/frontend/generic_frontend.py +++ b/i6_models/parts/frontend/generic_frontend.py @@ -85,14 +85,9 @@ def check_valid(self): assert len(self.layer_ordering) == num_convs + num_pools + num_activations, "Number of total layers mismatch!" - for kernel_sizes in [self.conv_kernel_sizes, self.pool_kernel_sizes]: - if kernel_sizes is not None: - for kernel_size in kernel_sizes: - if isinstance(kernel_size, int): - assert kernel_size % 2 == 1, "ConformerVGGFrontendV1 only supports odd kernel sizes" - elif isinstance(kernel_size, tuple): - for i in range(len(kernel_size)): - assert kernel_size[i] % 2 == 1, "ConformerVGGFrontendV1 only supports odd kernel sizes" + for kernel_sizes in filter(None, [self.conv_kernel_sizes, self.pool_kernel_sizes]): + for kernel_size in kernel_sizes: + assert all(k % 2 for k in kernel_size), "ConformerVGGFrontendV1 only supports odd kernel sizes" def __post__init__(self): super().__post_init__()