Skip to content

Commit

Permalink
fix(diffusers/models): fix several bugs
Browse files: browse the repository at this point in the history
  • Loading branch information
townwish4git committed Jun 6, 2024
1 parent 326c291 commit d5176fc
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mindone/diffusers/models/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def construct(
f"which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
- time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = self.add_time_proj(time_ids.flatten()).to(time_ids.dtype)
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

add_embeds = ops.concat([text_embeds, time_embeds], axis=-1)
Expand Down
2 changes: 1 addition & 1 deletion mindone/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def get_fourier_embeds_from_boundingbox(embed_dim, box):

batch_size, num_boxes = box.shape[:2]

- emb = 100 ** (ops.arange(embed_dim) / embed_dim)
+ emb = 100 ** (ops.arange(embed_dim).to(dtype=box.dtype) / embed_dim)
emb = emb[None, None, None].to(dtype=box.dtype)
emb = emb * box.unsqueeze(-1)

Expand Down
2 changes: 1 addition & 1 deletion mindone/diffusers/models/unets/unet_2d_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ def construct(
copied_cross_attention_kwargs = {}
for k, v in cross_attention_kwargs.items():
if k == "gligen":
- copied_cross_attention_kwargs[k] = {"obj": self.position_net(**v)}
+ copied_cross_attention_kwargs[k] = {"objs": self.position_net(**v)}
else:
copied_cross_attention_kwargs[k] = v
cross_attention_kwargs = copied_cross_attention_kwargs
Expand Down
7 changes: 4 additions & 3 deletions mindone/diffusers/models/unets/unet_motion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ def from_unet2d(

model = cls.from_config(config)

+ # Move dtype conversion code here to avoid dtype mismatch issues when loading weights
+ # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)

if not load_weights:
return model

Expand Down Expand Up @@ -481,9 +485,6 @@ def from_unet2d(
if has_motion_adapter:
model.load_motion_modules(motion_adapter)

- # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
- model.to(unet.dtype)

return model

def freeze_unet2d_params(self) -> None:
Expand Down

0 comments on commit d5176fc

Please sign in to comment.