Skip to content

Commit

Permalink
Init dtype and device correctly for OutputBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
DoryanKaced committed Sep 1, 2023
1 parent e91e31e commit 458b702
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class OutputBlock(fl.Chain):

def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
super().__init__(
fl.GroupNorm(channels=320, num_groups=32),
fl.GroupNorm(channels=320, num_groups=32, device=device, dtype=dtype),
fl.SiLU(),
fl.Conv2d(in_channels=320, out_channels=4, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
)
Expand Down

0 comments on commit 458b702

Please sign in to comment.