From 1685670145047ac59b215fa5b488b0746958e3fa Mon Sep 17 00:00:00 2001
From: nateraw
Date: Mon, 5 Dec 2022 17:25:01 +0000
Subject: [PATCH 1/5] :pushpin: move realesrgan to required deps

---
 requirements.txt | 5 +++--
 setup.py         | 4 ----
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 4a18f3f..b884399 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
-transformers
-diffusers==0.6.0
+transformers>=4.21.0
+diffusers==0.9.0
 scipy
 fire
 gradio
 librosa
 av<10.0.0
+realesrgan==0.2.5.0
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5994a9b..5813f7f 100644
--- a/setup.py
+++ b/setup.py
@@ -14,9 +14,6 @@ def get_version() -> str:
 with open("requirements.txt", "r") as f:
     requirements = f.read().splitlines()
 
-extras = {}
-extras['realesrgan'] = ['realesrgan==0.2.5.0']
-
 setup(
     name="stable_diffusion_videos",
     version=get_version(),
@@ -29,6 +26,5 @@ def get_version() -> str:
     long_description_content_type="text/markdown",
     license="Apache",
     install_requires=requirements,
-    extras_require=extras,
     packages=find_packages(),
 )

From de80244a1213af1630bffbfe4fc671be43da7185 Mon Sep 17 00:00:00 2001
From: nateraw
Date: Mon, 5 Dec 2022 17:26:05 +0000
Subject: [PATCH 2/5] :sparkles: add negative prompt, some minor 0.9.0 updates

---
 .../stable_diffusion_pipeline.py              | 92 +++++++++++++++++--
 1 file changed, 82 insertions(+), 10 deletions(-)

diff --git a/stable_diffusion_videos/stable_diffusion_pipeline.py b/stable_diffusion_videos/stable_diffusion_pipeline.py
index c379aac..9d0e32f 100644
--- a/stable_diffusion_videos/stable_diffusion_pipeline.py
+++ b/stable_diffusion_videos/stable_diffusion_pipeline.py
@@ -10,12 +10,20 @@
 import json
 
 import torch
+from packaging import version
 from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from diffusers.utils import deprecate, logging
-from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from diffusers.schedulers import (
+    DDIMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
@@ -166,9 +174,17 @@ def __init__(
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler,
+            PNDMScheduler,
+            LMSDiscreteScheduler,
+            EulerDiscreteScheduler,
+            EulerAncestralDiscreteScheduler,
+            DPMSolverMultistepScheduler,
+        ],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
+        requires_safety_checker: bool = True,
     ):
         super().__init__()
 
@@ -186,8 +202,21 @@ def __init__(
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if safety_checker is None:
-            logger.warn(
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
@@ -196,6 +225,33 @@ def __init__(
                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
             )
 
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+            )
+
+        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+            version.parse(unet.config._diffusers_version).base_version
+        ) < version.parse("0.9.0.dev0")
+        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+            deprecation_message = (
+                "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
+                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+                " in the config might lead to incorrect results in future versions. If you have downloaded this"
+                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+                " the `unet/config.json` file"
+            )
+            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(unet.config)
+            new_config["sample_size"] = 64
+            unet._internal_dict = FrozenDict(new_config)
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -205,6 +261,9 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.register_to_config(requires_safety_checker=requires_safety_checker)
+
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -218,9 +277,14 @@ def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
                 `attention_head_dim` must be a multiple of `slice_size`.
         """
         if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
+            if isinstance(self.unet.config.attention_head_dim, int):
+                # half the attention head size is usually a good trade-off between
+                # speed and memory
+                slice_size = self.unet.config.attention_head_dim // 2
+            else:
+                # if `attention_head_dim` is a list, take the smallest head size
+                slice_size = min(self.unet.config.attention_head_dim)
+
         self.unet.set_attention_slice(slice_size)
 
     def disable_attention_slicing(self):
@@ -361,7 +425,7 @@ def __call__(
         uncond_tokens: List[str]
         if negative_prompt is None:
             uncond_tokens = [""]
-        elif type(prompt) is not type(negative_prompt):
+        elif text_embeddings is None and type(prompt) is not type(negative_prompt):
             raise TypeError(
                 f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                 f" {type(prompt)}."
@@ -524,6 +588,7 @@ def make_clip_frames(
         image_file_ext: str = ".png",
         T: np.ndarray = None,
         skip: int = 0,
+        negative_prompt: str = None,
     ):
         save_path = Path(save_path)
         save_path.mkdir(parents=True, exist_ok=True)
@@ -559,6 +624,7 @@ def make_clip_frames(
                 eta=eta,
                 num_inference_steps=num_inference_steps,
                 output_type="pil" if not upsample else "numpy",
+                negative_prompt=negative_prompt,
             )["images"]
 
             for image in outputs:
@@ -588,6 +654,7 @@ def walk(
         audio_start_sec: Optional[Union[int, float]] = None,
         margin: Optional[float] = 1.0,
         smooth: Optional[float] = 0.0,
+        negative_prompt: Optional[str] = None,
     ):
         """Generate a video from a sequence of prompts and seeds. Optionally, add audio to the
         video to interpolate to the intensity of the audio.
@@ -638,6 +705,8 @@ def walk(
                 Margin from librosa hpss to use for audio interpolation.
             smooth (Optional[float], *optional*, defaults to 0.0):
                 Smoothness of the audio interpolation. 1.0 means linear interpolation.
+            negative_prompt (Optional[str], *optional*, defaults to None):
+                Optional negative prompt to use. Same across all prompts.
 
         This function will create sub directories for each prompt and seed pair.
@@ -710,6 +779,7 @@ def walk( width=width, audio_filepath=audio_filepath, audio_start_sec=audio_start_sec, + negative_prompt=negative_prompt, ), indent=2, sort_keys=False, @@ -729,6 +799,7 @@ def walk( width = data["width"] audio_filepath = data["audio_filepath"] audio_start_sec = data["audio_start_sec"] + negative_prompt = data.get("negative_prompt", None) for i, (prompt_a, prompt_b, seed_a, seed_b, num_step) in enumerate( zip(prompts, prompts[1:], seeds, seeds[1:], num_interpolation_steps) @@ -771,7 +842,6 @@ def walk( width=width, upsample=upsample, batch_size=batch_size, - skip=skip, T=get_timesteps_arr( audio_filepath, offset=audio_offset, @@ -782,6 +852,8 @@ def walk( ) if audio_filepath else None, + skip=skip, + negative_prompt=negative_prompt, ) make_video_pyav( save_path, @@ -805,7 +877,7 @@ def walk( sr=44100, ) - def embed_text(self, text): + def embed_text(self, text, negative_prompt=None): """Helper to embed some text""" with torch.autocast("cuda"): text_input = self.tokenizer( From e5d9c2f8ab253ebe58a640f27d624254b642cabd Mon Sep 17 00:00:00 2001 From: nateraw Date: Mon, 5 Dec 2022 17:26:24 +0000 Subject: [PATCH 3/5] :see_no_evil: add music files to gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ae91fbf..afb9846 100644 --- a/.gitignore +++ b/.gitignore @@ -132,4 +132,5 @@ dmypy.json dreams images run.py -test_outputs \ No newline at end of file +test_outputs +examples/music \ No newline at end of file From 420f87b0f5efad1b59f4045c7887347eae27686c Mon Sep 17 00:00:00 2001 From: nateraw Date: Mon, 5 Dec 2022 17:27:03 +0000 Subject: [PATCH 4/5] :pushpin: bump package version to 0.7.0dev --- stable_diffusion_videos/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py index a8517a8..b0f4870 100644 --- a/stable_diffusion_videos/__init__.py +++ b/stable_diffusion_videos/__init__.py @@ -114,4 +114,4 @@ def __dir__(): }, ) -__version__ = "0.6.2" +__version__ = "0.7.0dev" From c11d154375e3ab7e3919f73aaa9fed1d64ab4125 Mon Sep 17 00:00:00 2001 From: nateraw Date: Mon, 5 Dec 2022 17:50:58 +0000 Subject: [PATCH 5/5] :pushpin: update pinned version and update readme --- README.md | 8 ++------ stable_diffusion_videos/__init__.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index eb72c63..526ce88 100644 --- a/README.md +++ b/README.md @@ -137,13 +137,9 @@ Enjoy 🤗 You can also 4x upsample your images with [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN)! -First, you'll need to install it... +It's included when you pip install the latest version of `stable-diffusion-videos`! -```bash -pip install realesrgan -``` - -Then, you'll be able to use `upsample=True` in the `walk` function, like this: +You'll be able to use `upsample=True` in the `walk` function, like this: ```python pipeline.walk(['a cat', 'a dog'], [234, 345], upsample=True) diff --git a/stable_diffusion_videos/__init__.py b/stable_diffusion_videos/__init__.py index b0f4870..8943617 100644 --- a/stable_diffusion_videos/__init__.py +++ b/stable_diffusion_videos/__init__.py @@ -114,4 +114,4 @@ def __dir__(): }, ) -__version__ = "0.7.0dev" +__version__ = "0.7.0"
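
Taken together, this series exposes a `negative_prompt` option on `walk()` and ships Real-ESRGAN upsampling as a required dependency rather than an extra. A minimal usage sketch of the new options follows; the checkpoint id, prompts, seeds, and step count are illustrative placeholders (not values from the patches), and it assumes the `StableDiffusionWalkPipeline` entry point exported by this package:

```python
# Illustrative sketch only: exercises the negative_prompt and upsample options
# covered by this patch series. Checkpoint id, prompts, and seeds are placeholders.
import torch

from stable_diffusion_videos import StableDiffusionWalkPipeline

pipeline = StableDiffusionWalkPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # placeholder checkpoint
    torch_dtype=torch.float16,
).to("cuda")

video_path = pipeline.walk(
    prompts=["a cat", "a dog"],
    seeds=[42, 1337],
    num_interpolation_steps=3,
    negative_prompt="blurry, low quality",  # applied to every prompt in the walk
    upsample=True,  # Real-ESRGAN 4x upsampling; realesrgan is now installed by default
)
print(video_path)
```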