From d5176fc09cf832616952b37ef3697b5bc48681fd Mon Sep 17 00:00:00 2001 From: townwish4git Date: Fri, 7 Jun 2024 03:03:43 +0800 Subject: [PATCH] fix(diffusers/models): fix several bugs --- mindone/diffusers/models/controlnet.py | 2 +- mindone/diffusers/models/embeddings.py | 2 +- mindone/diffusers/models/unets/unet_2d_condition.py | 2 +- mindone/diffusers/models/unets/unet_motion_model.py | 7 ++++--- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mindone/diffusers/models/controlnet.py b/mindone/diffusers/models/controlnet.py index eb09275d3b..8dc8aad95c 100644 --- a/mindone/diffusers/models/controlnet.py +++ b/mindone/diffusers/models/controlnet.py @@ -733,7 +733,7 @@ def construct( f"which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" ) time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = self.add_time_proj(time_ids.flatten()).to(time_ids.dtype) time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) add_embeds = ops.concat([text_embeds, time_embeds], axis=-1) diff --git a/mindone/diffusers/models/embeddings.py b/mindone/diffusers/models/embeddings.py index f1f5abe26d..391d961102 100644 --- a/mindone/diffusers/models/embeddings.py +++ b/mindone/diffusers/models/embeddings.py @@ -652,7 +652,7 @@ def get_fourier_embeds_from_boundingbox(embed_dim, box): batch_size, num_boxes = box.shape[:2] - emb = 100 ** (ops.arange(embed_dim) / embed_dim) + emb = 100 ** (ops.arange(embed_dim).to(dtype=box.dtype) / embed_dim) emb = emb[None, None, None].to(dtype=box.dtype) emb = emb * box.unsqueeze(-1) diff --git a/mindone/diffusers/models/unets/unet_2d_condition.py b/mindone/diffusers/models/unets/unet_2d_condition.py index 1fd44dc9f5..eb076c1e8f 100644 --- a/mindone/diffusers/models/unets/unet_2d_condition.py +++ b/mindone/diffusers/models/unets/unet_2d_condition.py @@ -1034,7 +1034,7 @@ def construct( copied_cross_attention_kwargs = {} for k, v in cross_attention_kwargs.items(): if k == "gligen": - copied_cross_attention_kwargs[k] = {"obj": self.position_net(**v)} + copied_cross_attention_kwargs[k] = {"objs": self.position_net(**v)} else: copied_cross_attention_kwargs[k] = v cross_attention_kwargs = copied_cross_attention_kwargs diff --git a/mindone/diffusers/models/unets/unet_motion_model.py b/mindone/diffusers/models/unets/unet_motion_model.py index 3590b663c7..b373022e9a 100644 --- a/mindone/diffusers/models/unets/unet_motion_model.py +++ b/mindone/diffusers/models/unets/unet_motion_model.py @@ -440,6 +440,10 @@ def from_unet2d( model = cls.from_config(config) + # Move dtype conversion code here to avoid dtype mismatch issues when loading weights + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + if not load_weights: return model @@ -481,9 +485,6 @@ def from_unet2d( if has_motion_adapter: model.load_motion_modules(motion_adapter) - # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel - model.to(unet.dtype) - return model def freeze_unet2d_params(self) -> None: