Skip to content

Commit

Permalink
fix(diffusers): fix the bug of sdxl training
Browse files Browse the repository at this point in the history
  • Loading branch information
The-truthh committed Nov 7, 2024
1 parent 29ca665 commit 18955da
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions mindone/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,23 +120,24 @@ def register_to_config(self, **kwargs):

self._internal_dict = FrozenDict(internal_dict)

def __getattr__(self, name: str) -> Any:
"""The only reason we overwrite `getattr` here is to gracefully deprecate accessing
config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
This function is mostly copied from PyTorch's __getattr__ overwrite:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
"""

is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
is_attribute = name in self.__dict__

if is_in_config and not is_attribute:
deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." # noqa: E501
deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
return self._internal_dict[name]

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
# TODO: The __getattr__ function will fail in graph mode during training. It needs to be fixed later.
# def __getattr__(self, name: str) -> Any:
# """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
# config attributes directly. See https://github.com/huggingface/diffusers/pull/3129
#
# This function is mostly copied from PyTorch's __getattr__ overwrite:
# https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
# """
#
# is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
# is_attribute = name in self.__dict__
#
# if is_in_config and not is_attribute:
# deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'." # noqa: E501
# deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
# return self._internal_dict[name]
#
# raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Expand Down

0 comments on commit 18955da

Please sign in to comment.