
Commit

remove triton dependency of DC-AE and fix bugs; (NVlabs#38)
Signed-off-by: lawrence-cj <[email protected]>
Co-authored-by: Junyu Chen <[email protected]>
lawrence-cj and chenjy2003 authored Nov 24, 2024
1 parent fa267d5 commit 2a50a7f
Showing 2 changed files with 30 additions and 4 deletions.
@@ -48,7 +48,7 @@ class EncoderConfig:
     width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
     depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
     block_type: Any = "ResBlock"
-    norm: str = "trms2d"
+    norm: str = "rms2d"
     act: str = "silu"
     downsample_block_type: str = "ConvPixelUnshuffle"
     downsample_match_channel: bool = True
@@ -67,12 +67,12 @@ class DecoderConfig:
     width_list: tuple[int, ...] = (128, 256, 512, 512, 1024, 1024)
     depth_list: tuple[int, ...] = (2, 2, 2, 2, 2, 2)
     block_type: Any = "ResBlock"
-    norm: Any = "trms2d"
+    norm: Any = "rms2d"
     act: Any = "silu"
     upsample_block_type: str = "ConvPixelShuffle"
     upsample_match_channel: bool = True
     upsample_shortcut: str = "duplicating"
-    out_norm: str = "trms2d"
+    out_norm: str = "rms2d"
     out_act: str = "relu"
 
 
@@ -470,7 +470,7 @@ def dc_ae_f32c32(name: str, pretrained_path: str) -> DCAEConfig:
             "decoder.block_type=[ResBlock,ResBlock,ResBlock,EViTS5_GLU,EViTS5_GLU,EViTS5_GLU] "
             "decoder.width_list=[128,256,512,512,1024,1024] decoder.depth_list=[3,3,3,3,3,3] "
             "decoder.upsample_block_type=InterpolateConv "
-            "decoder.norm=trms2d decoder.act=silu "
+            "decoder.norm=rms2d decoder.act=silu "
             "scaling_factor=0.41407"
         )
     else:
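
Note: the three hunks above only change the default norm key from "trms2d" (Triton-backed) to "rms2d" (plain PyTorch). As a rough sketch of how such a key is resolved at model-build time, the snippet below uses a hypothetical helper, build_norm_example, purely for illustration; it is not code from this commit, and the real builder in efficientvit may pass additional keyword arguments. It only shows the lookup through REGISTERED_NORM_DICT defined in norm.py (shown next).

# Hypothetical lookup sketch -- not part of this commit.
from typing import Optional

import torch.nn as nn

from diffusion.model.dc_ae.efficientvit.models.nn.norm import REGISTERED_NORM_DICT


def build_norm_example(name: str, num_features: int) -> Optional[nn.Module]:
    # Resolve a norm key taken from a config string such as "decoder.norm=rms2d".
    if name not in REGISTERED_NORM_DICT:
        return None
    return REGISTERED_NORM_DICT[name](num_features)


norm_layer = build_norm_example("rms2d", num_features=128)  # resolves without Triton installed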
26 changes: 26 additions & 0 deletions diffusion/model/dc_ae/efficientvit/models/nn/norm.py
@@ -40,12 +40,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return TritonRMSNorm2dFunc.apply(x, self.weight, self.bias, self.eps)
 
 
+class RMSNorm2d(nn.Module):
+    def __init__(
+        self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True
+    ) -> None:
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+        if self.elementwise_affine:
+            self.weight = torch.nn.parameter.Parameter(torch.empty(self.num_features))
+            if bias:
+                self.bias = torch.nn.parameter.Parameter(torch.empty(self.num_features))
+            else:
+                self.register_parameter("bias", None)
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
+        if self.elementwise_affine:
+            x = x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
+        return x
+
+
 # register normalization function here
 REGISTERED_NORM_DICT: dict[str, type] = {
     "bn2d": nn.BatchNorm2d,
     "ln": nn.LayerNorm,
     "ln2d": LayerNorm2d,
     "trms2d": TritonRMSNorm2d,
+    "rms2d": RMSNorm2d,
 }


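For a quick sanity check of the new class, here is a minimal usage sketch (the import path follows the file location above; shapes are illustrative). Since the added RMSNorm2d allocates weight and bias with torch.empty and defines no reset_parameters, a standalone test should initialize them explicitly; in normal use they are loaded from a pretrained checkpoint.

import torch

from diffusion.model.dc_ae.efficientvit.models.nn.norm import RMSNorm2d

norm = RMSNorm2d(num_features=512)
torch.nn.init.ones_(norm.weight)   # parameters start uninitialized (torch.empty)
torch.nn.init.zeros_(norm.bias)

x = torch.randn(2, 512, 16, 16)    # (batch, channels, height, width)
y = norm(x)                        # RMS-normalized over the channel dimension, then affine
assert y.shape == x.shape and y.dtype == x.dtype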
