Skip to content

Commit

Permalink
Merge pull request #1187 from kohya-ss/fix-timeemb
Browse files Browse the repository at this point in the history
fix sdxl timestep embedding
  • Loading branch information
kohya-ss authored Mar 15, 2024
2 parents 2d73891 + f811b11 commit 6b1520a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,16 @@ ControlNet-LLLite, a novel method for ControlNet with SDXL, is added. See [docum

## Change History

### Mar 15, 2024 / 2024/3/15: v0.8.5

- Fixed a bug that the value of timestep embedding during SDXL training was incorrect.
- The inference with the generation script is also fixed.
- The impact is unknown, but please update for SDXL training.

- SDXL 学習時の timestep embedding の値が誤っていたのを修正しました。
- 生成スクリプトでの推論時についてもあわせて修正しました。
- 影響の度合いは不明ですが、SDXL の学習時にはアップデートをお願いいたします。

### Feb 24, 2024 / 2024/2/24: v0.8.4

- The log output has been improved. PR [#905](https://github.com/kohya-ss/sd-scripts/pull/905) Thanks to shirayu!
Expand Down
8 changes: 5 additions & 3 deletions library/sdxl_original_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
from torch.nn import functional as F
from einops import rearrange
from .utils import setup_logging

setup_logging()
import logging

logger = logging.getLogger(__name__)

IN_CHANNELS: int = 4
Expand Down Expand Up @@ -1074,7 +1076,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
timesteps = timesteps.expand(x.shape[0])

hs = []
t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = self.time_embed(t_emb)

Expand Down Expand Up @@ -1132,7 +1134,7 @@ def __init__(self, original_unet: SdxlUNet2DConditionModel, **kwargs):
# call original model's methods
def __getattr__(self, name):
return getattr(self.delegate, name)

def __call__(self, *args, **kwargs):
return self.delegate(*args, **kwargs)

Expand Down Expand Up @@ -1164,7 +1166,7 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
timesteps = timesteps.expand(x.shape[0])

hs = []
t_emb = get_timestep_embedding(timesteps, _self.model_channels) # , repeat_only=False)
t_emb = get_timestep_embedding(timesteps, _self.model_channels, downscale_freq_shift=0) # , repeat_only=False)
t_emb = t_emb.to(x.dtype)
emb = _self.time_embed(t_emb)

Expand Down

0 comments on commit 6b1520a

Please sign in to comment.