Commit 47fe103

Merge branch 'main' of https://github.com/huggingface/diffusers into main

Skquark committed Feb 24, 2024
2 parents c910d06 + c09bb58 commit 47fe103
Showing 44 changed files with 890 additions and 404 deletions.
docs/source/en/using-diffusers/write_own_pipeline.md (2 changes: 1 addition & 1 deletion)

@@ -273,7 +273,7 @@ Lastly, convert the image to a `PIL.Image` to see your generated image!
```py
>>> image = (image / 2 + 0.5).clamp(0, 1).squeeze()
>>> image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
- >>> images = (image * 255).round().astype("uint8")
+ >>> image = (image * 255).round().astype("uint8")
>>> image = Image.fromarray(image)
>>> image
```
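The rename (`images` → `image`) means the rounded array is the one later lines actually consume, instead of sitting unused under a stray plural name. For reference, a condensed, self-contained version of this post-processing step (the random tensor is a stand-in for the VAE-decoded sample, which the excerpt does not show, and the scale-to-`uint8` conversion is done once):

```py
import torch
from PIL import Image

# Stand-in for a decoded VAE sample in [-1, 1], shape (batch, channels, height, width).
image = torch.rand(1, 3, 64, 64) * 2 - 1

image = (image / 2 + 0.5).clamp(0, 1).squeeze()  # [-1, 1] -> [0, 1], drop the batch dim
image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()  # CHW -> HWC uint8
pil_image = Image.fromarray(image)
```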
examples/community/imagic_stable_diffusion.py (15 changes: 6 additions & 9 deletions)

@@ -370,8 +370,9 @@ def __call__(
r"""
Function invoked when calling the pipeline for generation.
Args:
- prompt (`str` or `List[str]`):
-     The prompt or prompts to guide the image generation.
+ alpha (`float`, *optional*, defaults to 1.2):
+     The interpolation factor between the original and optimized text embeddings. A value closer to 0
+     will resemble the original input image.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
@@ -385,16 +386,9 @@
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
- eta (`float`, *optional*, defaults to 0.0):
-     Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-     [`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
- latents (`torch.FloatTensor`, *optional*):
-     Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-     generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-     tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -407,6 +401,9 @@
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
+ eta (`float`, *optional*, defaults to 0.0):
+     Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+     [`schedulers.DDIMScheduler`], will be ignored for others.
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
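For context, the `alpha` documented above blends the text embedding that Imagic optimizes to reconstruct the input image with the embedding of the edit prompt. A rough sketch of that interpolation, consistent with the docstring (the tensors below are hypothetical stand-ins; the real pipeline produces them during its `train` step):

```py
import torch

# Hypothetical stand-ins for the two embeddings the pipeline holds after train().
optimized_embeddings = torch.randn(1, 77, 768)  # tuned to reconstruct the input image
target_embeddings = torch.randn(1, 77, 768)     # encodes the edit prompt

alpha = 1.2  # 0 -> reproduce the input image; ~1 and above -> stronger edit
embeddings = alpha * target_embeddings + (1 - alpha) * optimized_embeddings
```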
examples/community/stable_diffusion_reference.py (2 changes: 1 addition & 1 deletion)

@@ -31,7 +31,7 @@
torch_dtype=torch.float16
).to('cuda:0')
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config)
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> result_img = pipe(ref_image=input_image,
prompt="1girl",
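The fix makes the docstring example read the config from the pipeline it actually constructs (`pipe`) rather than from `pipe_controlnet`, a variable that belongs to a different example. The scheduler-swap pattern itself, as a standalone sketch (the model id is only illustrative):

```py
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# Rebuild the scheduler from the current scheduler's own config.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
```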
examples/community/stable_diffusion_tensorrt_inpaint.py (2 changes: 1 addition & 1 deletion)

@@ -1011,7 +1011,7 @@ def __call__(
"""
self.generator = generator
self.denoising_steps = num_inference_steps
- self.guidance_scale = guidance_scale
+ self._guidance_scale = guidance_scale

# Pre-compute latent input scales and linear multistep coefficients
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
examples/community/stable_diffusion_tensorrt_txt2img.py (2 changes: 1 addition & 1 deletion)

@@ -882,7 +882,7 @@ def __call__(
"""
self.generator = generator
self.denoising_steps = num_inference_steps
- self.guidance_scale = guidance_scale
+ self._guidance_scale = guidance_scale

# Pre-compute latent input scales and linear multistep coefficients
self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device)
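The same rename appears in both TensorRT pipelines. The likely reason (an inference from the diffusers base class, not stated in this diff) is that recent base pipelines expose `guidance_scale` as a read-only property backed by `self._guidance_scale`, so assigning to the property raises `AttributeError` while writing the private field does not. The stored value drives classifier-free guidance, per the Imagen-paper equation referenced earlier; a minimal sketch of that step, with random stand-ins for the two noise predictions:

```py
import torch

# Stand-ins for the unconditional and text-conditioned noise predictions.
noise_pred_uncond = torch.randn(1, 4, 64, 64)
noise_pred_text = torch.randn(1, 4, 64, 64)

guidance_scale = 7.5  # values > 1 enable classifier-free guidance
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```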
examples/text_to_image/train_text_to_image_lora.py (13 changes: 9 additions & 4 deletions)

@@ -58,12 +58,17 @@


def save_model_card(
-     repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None
+     repo_id: str,
+     images: list = None,
+     base_model: str = None,
+     dataset_name: str = None,
+     repo_folder: str = None,
):
img_str = ""
-     for i, image in enumerate(images):
-         image.save(os.path.join(repo_folder, f"image_{i}.png"))
-         img_str += f"![img_{i}](./image_{i}.png)\n"
+     if images is not None:
+         for i, image in enumerate(images):
+             image.save(os.path.join(repo_folder, f"image_{i}.png"))
+             img_str += f"![img_{i}](./image_{i}.png)\n"

model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
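The new guard matters because `images` defaults to `None`, and `enumerate(None)` raises immediately, so any run without validation images used to crash while writing the model card. A tiny runnable demonstration of the failure mode and the fixed pattern:

```py
images = None

try:
    for i, image in enumerate(images):  # what the old code effectively did
        pass
except TypeError as err:
    print(err)  # 'NoneType' object is not iterable

# The fixed pattern simply skips the gallery when no images were generated:
img_str = ""
if images is not None:
    for i, image in enumerate(images):
        img_str += f"![img_{i}](./image_{i}.png)\n"
```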
examples/text_to_image/train_text_to_image_sdxl.py (7 changes: 4 additions & 3 deletions)

@@ -74,9 +74,10 @@ def save_model_card(
vae_path: str = None,
):
img_str = ""
-     for i, image in enumerate(images):
-         image.save(os.path.join(repo_folder, f"image_{i}.png"))
-         img_str += f"![img_{i}](./image_{i}.png)\n"
+     if images is not None:
+         for i, image in enumerate(images):
+             image.save(os.path.join(repo_folder, f"image_{i}.png"))
+             img_str += f"![img_{i}](./image_{i}.png)\n"

model_description = f"""
# Text-to-image finetuning - {repo_id}
src/diffusers/loaders/lora.py (2 changes: 1 addition & 1 deletion)

@@ -1192,7 +1192,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device,
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

- # Overrride to properly handle the loading and unloading of the additional text encoder.
+ # Override to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
src/diffusers/loaders/textual_inversion.py (2 changes: 1 addition & 1 deletion)

@@ -215,7 +215,7 @@ def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
embedding = state_dict["string_to_param"]["*"]
else:
raise ValueError(
f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
" input key."
)
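For reference, the error message above describes the two layouts the loader accepts: a state dict with a single `{token: embedding}` entry, or one containing the `string_to_param` key whose `"*"` entry holds the vectors. A sketch of both layouts (shapes are illustrative):

```py
import torch

# Layout 1: a bare single-key mapping from placeholder token to its embedding.
diffusers_format = {"<my-concept>": torch.randn(768)}

# Layout 2: the `string_to_param` layout, where "*" maps to one or more vectors.
a1111_format = {"string_to_param": {"*": torch.randn(2, 768)}}
```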
src/diffusers/models/unets/unet_2d_blocks.py (75 changes: 75 additions & 0 deletions)

@@ -249,6 +249,81 @@ def get_down_block(
raise ValueError(f"{down_block_type} does not exist.")


def get_mid_block(
mid_block_type: str,
temb_channels: int,
in_channels: int,
resnet_eps: float,
resnet_act_fn: str,
resnet_groups: int,
output_scale_factor: float = 1.0,
transformer_layers_per_block: int = 1,
num_attention_heads: Optional[int] = None,
cross_attention_dim: Optional[int] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
mid_block_only_cross_attention: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
resnet_skip_time_act: bool = False,
cross_attention_norm: Optional[str] = None,
attention_head_dim: Optional[int] = 1,
dropout: float = 0.0,
):
if mid_block_type == "UNetMidBlock2DCrossAttn":
return UNetMidBlock2DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block,
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads,
resnet_groups=resnet_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
return UNetMidBlock2DSimpleCrossAttn(
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
cross_attention_dim=cross_attention_dim,
attention_head_dim=attention_head_dim,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
skip_time_act=resnet_skip_time_act,
only_cross_attention=mid_block_only_cross_attention,
cross_attention_norm=cross_attention_norm,
)
elif mid_block_type == "UNetMidBlock2D":
return UNetMidBlock2D(
in_channels=in_channels,
temb_channels=temb_channels,
dropout=dropout,
num_layers=0,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
output_scale_factor=output_scale_factor,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
add_attention=False,
)
elif mid_block_type is None:
return None
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")


def get_up_block(
up_block_type: str,
num_layers: int,
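A sketch of how the new factory might be invoked; the argument values are illustrative SD-style settings rather than anything taken from this diff, and it assumes a diffusers version that ships this function at this path:

```py
from diffusers.models.unets.unet_2d_blocks import get_mid_block

mid_block = get_mid_block(
    "UNetMidBlock2DCrossAttn",
    temb_channels=1280,
    in_channels=1280,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    cross_attention_dim=768,
    num_attention_heads=8,
)

# Per the branches above, mid_block_type=None is also accepted and returns None.
```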