Skip to content

Commit

Permalink
fix(diffusers/models): remove additional arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
townwish4git committed Jun 5, 2024
1 parent 871ad0f commit 326c291
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 19 deletions.
22 changes: 7 additions & 15 deletions mindone/diffusers/models/unets/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1510,7 +1510,7 @@ def __init__(
else:
self.downsamplers = None

def construct(self, hidden_states: ms.Tensor, *args, **kwargs) -> ms.Tensor:
def construct(self, hidden_states: ms.Tensor) -> ms.Tensor:
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)
Expand Down Expand Up @@ -1611,8 +1611,6 @@ def construct(
hidden_states: ms.Tensor,
temb: Optional[ms.Tensor] = None,
skip_sample: Optional[ms.Tensor] = None,
*args,
**kwargs,
) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]:
output_states = ()

Expand Down Expand Up @@ -1700,8 +1698,6 @@ def construct(
hidden_states: ms.Tensor,
temb: Optional[ms.Tensor] = None,
skip_sample: Optional[ms.Tensor] = None,
*args,
**kwargs,
) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]:
output_states = ()

Expand Down Expand Up @@ -1787,7 +1783,9 @@ def __init__(
self.has_cross_attention = False

def construct(
self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, **kwargs # *args, **kwargs
self,
hidden_states: ms.Tensor,
temb: Optional[ms.Tensor] = None,
) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]:
output_states = ()

Expand Down Expand Up @@ -1991,7 +1989,9 @@ def __init__(
self.has_cross_attention = False

def construct(
self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, *args, **kwargs
self,
hidden_states: ms.Tensor,
temb: Optional[ms.Tensor] = None,
) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]:
output_states = ()

Expand Down Expand Up @@ -2751,8 +2751,6 @@ def construct(
res_hidden_states_tuple: Tuple[ms.Tensor, ...],
temb: Optional[ms.Tensor] = None,
skip_sample=None,
*args,
**kwargs,
) -> Tuple[ms.Tensor, ms.Tensor]:
for resnet in self.resnets:
# pop res hidden states
Expand Down Expand Up @@ -2863,8 +2861,6 @@ def construct(
res_hidden_states_tuple: Tuple[ms.Tensor, ...],
temb: Optional[ms.Tensor] = None,
skip_sample=None,
*args,
**kwargs,
) -> Tuple[ms.Tensor, ms.Tensor]:
for resnet in self.resnets:
# pop res hidden states
Expand Down Expand Up @@ -2967,8 +2963,6 @@ def construct(
res_hidden_states_tuple: Tuple[ms.Tensor, ...],
temb: Optional[ms.Tensor] = None,
upsample_size: Optional[int] = None,
# *args,
**kwargs,
) -> ms.Tensor:
for resnet in self.resnets:
# pop res hidden states
Expand Down Expand Up @@ -3186,8 +3180,6 @@ def construct(
res_hidden_states_tuple: Tuple[ms.Tensor, ...],
temb: Optional[ms.Tensor] = None,
upsample_size: Optional[int] = None,
*args,
**kwargs,
) -> ms.Tensor:
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
Expand Down
4 changes: 0 additions & 4 deletions mindone/diffusers/models/unets/unet_3d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,6 @@ def construct(
hidden_states: ms.Tensor,
temb: Optional[ms.Tensor] = None,
num_frames: int = 1,
# *args,
**kwargs,
) -> Union[ms.Tensor, Tuple[ms.Tensor, ...]]:
output_states = ()

Expand Down Expand Up @@ -1404,8 +1402,6 @@ def construct(
temb: Optional[ms.Tensor] = None,
upsample_size=None,
num_frames: int = 1,
# *args,
**kwargs,
) -> ms.Tensor:
is_freeu_enabled = (
getattr(self, "s1", None)
Expand Down

0 comments on commit 326c291

Please sign in to comment.