From 2137898546f7c695bdddcfe5f75dee0ec8a9944c Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 6 Jun 2024 17:17:43 -0700 Subject: [PATCH] Fix AQT conv general construction. PiperOrigin-RevId: 641070533 --- chirp/models/efficientnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chirp/models/efficientnet.py b/chirp/models/efficientnet.py index 4e034493..2492544e 100644 --- a/chirp/models/efficientnet.py +++ b/chirp/models/efficientnet.py @@ -23,7 +23,7 @@ from typing import Callable, NamedTuple from aqt.jax.v2 import aqt_conv_general -from aqt.jax.v2 import config as aqt_cfg +from aqt.jax.v2 import aqt_dot_general from chirp.models import layers from flax import linen as nn import flax.typing as flax_typing @@ -133,7 +133,7 @@ class OpSet: head_activation=nn.hard_swish, dot_general=jax.lax.dot_general, conv_general_dilated=aqt_conv_general.make_conv_general_dilated( - aqt_cfg.DotGeneralRaw.make_conv_general_dilated() + aqt_conv_general.conv_general_dilated_make(spatial_dimensions=2) ), ), }