Skip to content

Commit

Permalink
Re-fix AQT conv construction.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641333993
  • Loading branch information
sdenton4 authored and copybara-github committed Jun 7, 2024
1 parent 1ae4029 commit 64d2f0e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions chirp/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Callable, NamedTuple

from aqt.jax.v2 import aqt_conv_general
from aqt.jax.v2 import aqt_dot_general
from aqt.jax.v2 import config as aqt_cfg # pylint: disable=unused-import
from chirp.models import layers
from flax import linen as nn
import flax.typing as flax_typing
Expand Down Expand Up @@ -133,7 +133,7 @@ class OpSet:
head_activation=nn.hard_swish,
dot_general=jax.lax.dot_general,
conv_general_dilated=aqt_conv_general.make_conv_general_dilated(
aqt_conv_general.conv_general_dilated_make(spatial_dimensions=2)
aqt_cfg.conv_general_dilated_make(spatial_dimensions=2)
),
),
}
Expand Down

0 comments on commit 64d2f0e

Please sign in to comment.