Skip to content

Commit

Permalink
Fix AQT conv general construction.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 641070533
  • Loading branch information
sdenton4 authored and copybara-github committed Jun 7, 2024
1 parent ea1c5a0 commit 2137898
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions chirp/models/efficientnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from typing import Callable, NamedTuple

from aqt.jax.v2 import aqt_conv_general
from aqt.jax.v2 import config as aqt_cfg
from aqt.jax.v2 import aqt_dot_general
from chirp.models import layers
from flax import linen as nn
import flax.typing as flax_typing
Expand Down Expand Up @@ -133,7 +133,7 @@ class OpSet:
head_activation=nn.hard_swish,
dot_general=jax.lax.dot_general,
conv_general_dilated=aqt_conv_general.make_conv_general_dilated(
aqt_cfg.DotGeneralRaw.make_conv_general_dilated()
aqt_conv_general.conv_general_dilated_make(spatial_dimensions=2)
),
),
}
Expand Down

0 comments on commit 2137898

Please sign in to comment.