
feat(diffusers/models): add models such as autoencoders, transformers, UNets, etc. #523

Merged
merged 1 commit into from
Jun 28, 2024

Conversation

townwish4git
Contributor

@townwish4git townwish4git commented Jun 3, 2024

What does this PR do?

Adds # (feature)

Implements the models in mindone.diffusers.models, including:

Models

AutoEncoders

  • AutoencoderKL
  • AsymmetricAutoencoderKL
  • AutoencoderKLTemporalDecoder
  • ConsistencyDecoderVAE
  • AutoencoderTiny
  • VQModel

UNets

  • UNet1DModel
  • UNet2DModel
  • UNet2DConditionModel
  • UNet3DConditionModel
  • I2VGenXLUNet
  • Kandinsky3UNet
  • UNetSpatioTemporalConditionModel
  • UNetMotionModel
  • StableCascadeUNet
  • UViT2DModel

Transformers

  • Transformer2DModel
  • TransformerTemporalModel
  • T5FilmDecoder
  • PriorTransformer
  • DualTransformer2DModel

supplement: https://gist.github.com/Cui-yshoho/7ff86a76323c37f1c197d7fef67702ec

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? E.g. record bug fixes or new features in What's New. Here are the
    documentation guidelines
  • Did you build and run the code without any errors?
  • Did you report the running environment (NPU type/MS version) and performance in the doc? (better record it for data loading, model inference, or training tasks)
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@xxx

@geniuspatrick left a comment

After yushi's and huiyu's changes are merged, please rebase and resolve conflicts.

@townwish4git
Contributor Author

After yushi's and huiyu's changes are merged, please rebase and resolve conflicts.

done.

mindone/diffusers/README.md
@@ -134,6 +150,39 @@ Most base, utility and mixin class are available.
Unlike upstream, where the output `posterior = DiagonalGaussianDistribution(latent)` supports sampling via `posterior.sample()`,
we can only output the `latent` and then sample through `AutoencoderKL.diag_gauss_dist.sample(latent)`.
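For context, a minimal numpy sketch of what sampling from such a diagonal Gaussian involves. Splitting the latent into mean and log-variance halves along the channel axis mirrors diffusers' `DiagonalGaussianDistribution`; the helper name here is illustrative, not the mindone API:

```python
import numpy as np

def diag_gauss_sample(latent: np.ndarray, rng=None) -> np.ndarray:
    """Sample from a diagonal Gaussian parameterized by `latent`.

    The latent is split along the channel axis (axis 1) into mean and
    log-variance halves, as in diffusers' DiagonalGaussianDistribution.
    """
    rng = rng or np.random.default_rng(0)
    mean, logvar = np.split(latent, 2, axis=1)
    logvar = np.clip(logvar, -30.0, 20.0)  # clamp log-variance for numerical stability
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape).astype(latent.dtype)

latent = np.zeros((1, 8, 4, 4), dtype=np.float32)  # 2 * 4 latent channels
sample = diag_gauss_sample(latent)
print(sample.shape)  # (1, 4, 4, 4)
```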

### `nn.Conv3d`
Collaborator

Delete this section; it does not need to be listed.

Collaborator

Put the support matrix for different precisions and modes, mentioned earlier, in the PR description for now.

Collaborator

TODO: later, consider merging this with yushi's list and the limitation section of the README, extracting a standalone Limitation.md.

Contributor Author

That section has been deleted. I had also posted the nn.Conv3d impact notes below yushi's list; I have pasted them here.

@@ -430,7 +430,7 @@ def get_attention_scores(self, query: ms.Tensor, key: ms.Tensor, attention_mask:
)
else:
attention_scores = ops.baddbmm(
attention_mask,
attention_mask.to(query.dtype),
Collaborator

Why is this cast needed?

Contributor Author

ops.baddbmm internally checks that attention_mask, query, and key all share the same dtype. With upcast_attention, the mask's dtype differs from the other two, which raises an error.
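A small numpy sketch of the failure mode. The `baddbmm` stand-in below is illustrative, not the MindSpore kernel; it only mimics the dtype-consistency check described above (baddbmm computes `beta * input + alpha * bmm(batch1, batch2)`):

```python
import numpy as np

def baddbmm(inp, batch1, batch2, *, beta=1.0, alpha=1.0):
    # Toy stand-in for ops.baddbmm that enforces the dtype check
    # the real kernel performs: all three tensors share one dtype.
    if not (inp.dtype == batch1.dtype == batch2.dtype):
        raise TypeError(f"dtype mismatch: {inp.dtype}, {batch1.dtype}, {batch2.dtype}")
    return beta * inp + alpha * (batch1 @ batch2)

query = np.ones((2, 3, 4), dtype=np.float32)   # upcast_attention: fp32
key_t = np.ones((2, 4, 3), dtype=np.float32)
mask = np.zeros((2, 3, 3), dtype=np.float16)   # mask still in fp16

try:
    baddbmm(mask, query, key_t)                # rejected: mixed dtypes
except TypeError:
    pass
scores = baddbmm(mask.astype(query.dtype), query, key_t)  # cast fixes it
```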

@@ -475,7 +475,9 @@ def prepare_attention_mask(
# we want to instead pad by (0, remaining_length), where remaining_length is:
# remaining_length: int = target_length - current_length
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
attention_mask = ops.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = ops.Pad(paddings=((0, 0),) * (attention_mask.ndim - 1) + ((0, target_length),))(
Collaborator

Is the manual `* (attention_mask.ndim - 1)` necessary?
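For context, `ops.Pad` takes one explicit `(before, after)` pair per axis, whereas `ops.pad` with a short tuple pads only the trailing axes. The expression in the diff expands a last-axis pad to every axis; it can be checked in plain Python:

```python
def build_paddings(ndim: int, target_length: int):
    # One (pad_before, pad_after) pair per axis; only the last axis is padded.
    return ((0, 0),) * (ndim - 1) + ((0, target_length),)

print(build_paddings(3, 5))  # ((0, 0), (0, 0), (0, 5))
print(build_paddings(1, 2))  # ((0, 2),)
```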

Contributor Author

def construct(self, x: ms.Tensor, mask=None) -> ms.Tensor:
r"""The forward method of the `MaskConditionEncoder` class."""
out = {}
for l in range(len(self.layers)): # noqa: E741
Collaborator

why noqa?

Contributor Author

A standalone single-letter variable `l` violates the lint rule; the noqa keeps the variable name consistent with diffusers.

self.legacy = legacy

self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
self.embedding.embedding_table.set_data(
Collaborator

TODO: init method refactor

Contributor Author

Done, please review.
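For reference, upstream diffusers initializes the VQ codebook uniformly in [-1/n_e, 1/n_e]. A numpy sketch of that init — an assumption about what the refactored `set_data` call receives, not the merged code:

```python
import numpy as np

def init_codebook(n_e: int, vq_embed_dim: int, seed: int = 0) -> np.ndarray:
    """Uniform codebook init in [-1/n_e, 1/n_e], as in upstream diffusers VQModel."""
    rng = np.random.default_rng(seed)
    return rng.uniform(-1.0 / n_e, 1.0 / n_e, size=(n_e, vq_embed_dim)).astype(np.float32)

codebook = init_codebook(256, 4)
print(codebook.shape)  # (256, 4)
```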


def construct(self, input: ms.Tensor) -> ms.Tensor:
return (
0.5 * input * (1.0 + ops.tanh(float(math.sqrt(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0))))
Collaborator

Could using math inside construct cause problems? Consider hoisting the constant outside.

Contributor Author

In graph mode, type inference on the return value of math.xxx() can fail, so it is wrapped in float() to make it recognizable as a Python built-in type. An issue on the Gitee MindSpore tracker mentioned this problem and this workaround, but I can no longer find the link.
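A minimal illustration of the workaround in plain Python (whether graph mode requires it is the MindSpore behavior described above):

```python
import math

# math.sqrt already returns a Python float, but wrapping it in float()
# makes the type explicit so graph-mode inference treats it as a
# built-in scalar rather than an opaque call result.
SQRT_2_OVER_PI = float(math.sqrt(2.0 / math.pi))
print(f"{SQRT_2_OVER_PI:.6f}")  # 0.797885
```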

Collaborator

Why are generic and transformers separated?

Contributor Author

# In cases where models have unique initialization procedures or require testing with specialized output formats,
# it is necessary to develop distinct, dedicated test cases.

The input/output format of t5_film_decoder does not fit the generic test-case structure, so it was split out. 🤔 We could consider a clearer name for the T5 test file, e.g. test_t5_film_decoder.py.

@townwish4git townwish4git force-pushed the 0529models branch 2 times, most recently from 2a6bbde to 60b2ad2 Compare June 20, 2024 06:42
multiplications and allows for post-hoc remapping of indices.
"""

# NOTE: due to a bug the beta term was applied to the wrong term. for
Collaborator

what bug?

Contributor Author

This comment comes from the original diffusers, not the MindSpore version 😄


def construct(self, input: ms.Tensor) -> ms.Tensor:
return (
0.5 * input * (1.0 + ops.tanh(float(math.sqrt(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0))))
Collaborator

What is the expected dtype here? The same as the input's? Type inference here may be problematic in PyNative mode.

Contributor Author

What is the expected dtype here? The same as the input's? Type inference here may be problematic in PyNative mode.

It is expected to match the input. Under what circumstances would there be a problem? The logic here looks fine so far:

# python-built-in-float * ms.Tensor(dtype=...) => ms.Tensor(dtype=same...)
float(math.sqrt(2.0 / math.pi)) * (input + 0.044715 * ops.pow(input, 3.0))

Contributor Author

Currently the unit tests for {PyNative, graph} x {fp16, fp32} all pass with precision matching torch; perhaps bf16 would be a problem? I have indeed run into cases where ms.Tensor(bf16) * python-built-in-float => ms.Tensor(fp32) 😂
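The scalar-times-tensor promotion can be sanity-checked in numpy (MindSpore's bf16 behavior may differ, as noted above, and numpy has no bfloat16 to test with):

```python
import math
import numpy as np

x = np.ones((2, 2), dtype=np.float16)
y = float(math.sqrt(2.0 / math.pi)) * x  # Python float * fp16 array
print(y.dtype)  # float16: the array dtype wins for Python scalars
```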

@townwish4git
Contributor Author

Used the magic number 0.797885 to replace math.sqrt(2.0 / math.pi); the unit tests pass on CPU in {fp16, fp32} x {pynative, graph} with results almost identical to PyTorch's.
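A quick numpy check that the magic number is a faithful stand-in for the exact constant in the tanh-GELU formula (illustrative only; the unit tests above are the real verification):

```python
import math
import numpy as np

MAGIC = 0.797885
EXACT = math.sqrt(2.0 / math.pi)  # 0.7978845608...

def gelu_tanh(x: np.ndarray, c: float) -> np.ndarray:
    # tanh approximation of GELU, matching the construct() above
    return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x**3)))

x = np.linspace(-4, 4, 1001, dtype=np.float32)
diff = np.abs(gelu_tanh(x, MAGIC) - gelu_tanh(x, EXACT)).max()
print(float(diff) < 1e-5)  # True: far below fp16 resolution
```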

@CaitinZhao CaitinZhao added this pull request to the merge queue Jun 28, 2024
Merged via the queue into mindspore-lab:master with commit f4328d1 Jun 28, 2024
2 checks passed