From 64d2f0ef9571ac75fe4c9e7dad0d754503de7707 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Fri, 7 Jun 2024 13:02:01 -0700 Subject: [PATCH] Re-fix AQT conv construction. PiperOrigin-RevId: 641333993 --- chirp/models/efficientnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chirp/models/efficientnet.py b/chirp/models/efficientnet.py index 2492544e..723e74cb 100644 --- a/chirp/models/efficientnet.py +++ b/chirp/models/efficientnet.py @@ -23,7 +23,7 @@ from typing import Callable, NamedTuple from aqt.jax.v2 import aqt_conv_general -from aqt.jax.v2 import aqt_dot_general +from aqt.jax.v2 import config as aqt_cfg # pylint: disable=unused-import from chirp.models import layers from flax import linen as nn import flax.typing as flax_typing @@ -133,7 +133,7 @@ class OpSet: head_activation=nn.hard_swish, dot_general=jax.lax.dot_general, conv_general_dilated=aqt_conv_general.make_conv_general_dilated( - aqt_conv_general.conv_general_dilated_make(spatial_dimensions=2) + aqt_cfg.conv_general_dilated_make(spatial_dimensions=2) ), ), }