diff --git a/mindone/diffusers/models/unets/unet_2d_blocks.py b/mindone/diffusers/models/unets/unet_2d_blocks.py index d4d144983e..c2b909465d 100644 --- a/mindone/diffusers/models/unets/unet_2d_blocks.py +++ b/mindone/diffusers/models/unets/unet_2d_blocks.py @@ -1510,7 +1510,7 @@ def __init__( else: self.downsamplers = None - def construct(self, hidden_states: ms.Tensor, *args, **kwargs) -> ms.Tensor: + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: for resnet, attn in zip(self.resnets, self.attentions): hidden_states = resnet(hidden_states, temb=None) hidden_states = attn(hidden_states) @@ -1611,8 +1611,6 @@ def construct( hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, skip_sample: Optional[ms.Tensor] = None, - *args, - **kwargs, ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]: output_states = () @@ -1700,8 +1698,6 @@ def construct( hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, skip_sample: Optional[ms.Tensor] = None, - *args, - **kwargs, ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...], ms.Tensor]: output_states = () @@ -1787,7 +1783,9 @@ def __init__( self.has_cross_attention = False def construct( - self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, **kwargs # *args, **kwargs + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: output_states = () @@ -1991,7 +1989,9 @@ def __init__( self.has_cross_attention = False def construct( - self, hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, *args, **kwargs + self, + hidden_states: ms.Tensor, + temb: Optional[ms.Tensor] = None, ) -> Tuple[ms.Tensor, Tuple[ms.Tensor, ...]]: output_states = () @@ -2751,8 +2751,6 @@ def construct( res_hidden_states_tuple: Tuple[ms.Tensor, ...], temb: Optional[ms.Tensor] = None, skip_sample=None, - *args, - **kwargs, ) -> Tuple[ms.Tensor, ms.Tensor]: for resnet in self.resnets: # pop res hidden states @@ -2863,8 +2861,6 @@ def construct( res_hidden_states_tuple: Tuple[ms.Tensor, ...], temb: Optional[ms.Tensor] = None, skip_sample=None, - *args, - **kwargs, ) -> Tuple[ms.Tensor, ms.Tensor]: for resnet in self.resnets: # pop res hidden states @@ -2967,8 +2963,6 @@ def construct( res_hidden_states_tuple: Tuple[ms.Tensor, ...], temb: Optional[ms.Tensor] = None, upsample_size: Optional[int] = None, - # *args, - **kwargs, ) -> ms.Tensor: for resnet in self.resnets: # pop res hidden states @@ -3186,8 +3180,6 @@ def construct( res_hidden_states_tuple: Tuple[ms.Tensor, ...], temb: Optional[ms.Tensor] = None, upsample_size: Optional[int] = None, - *args, - **kwargs, ) -> ms.Tensor: res_hidden_states_tuple = res_hidden_states_tuple[-1] if res_hidden_states_tuple is not None: diff --git a/mindone/diffusers/models/unets/unet_3d_blocks.py b/mindone/diffusers/models/unets/unet_3d_blocks.py index 7d78aaa11b..b2d63c4c38 100644 --- a/mindone/diffusers/models/unets/unet_3d_blocks.py +++ b/mindone/diffusers/models/unets/unet_3d_blocks.py @@ -985,8 +985,6 @@ def construct( hidden_states: ms.Tensor, temb: Optional[ms.Tensor] = None, num_frames: int = 1, - # *args, - **kwargs, ) -> Union[ms.Tensor, Tuple[ms.Tensor, ...]]: output_states = () @@ -1404,8 +1402,6 @@ def construct( temb: Optional[ms.Tensor] = None, upsample_size=None, num_frames: int = 1, - # *args, - **kwargs, ) -> ms.Tensor: is_freeu_enabled = ( getattr(self, "s1", None)