add slurm files
add differentiable decode to SDXL VAE

Optionally return predicted noise during the single-step sampling process
* also split `get_gamma` out as a new function so it can be used inside other
  functions that interact with sampling (e.g. draft+)

debugging sdunet converter script

Added SD/SDXL conversion script from HF to NeMo
* added 'from_nemo' config for VAE

tmp commit, please make changes (oci is super slow, cannot even run vim)

new inference yaml works

add logging to autoencoder

!(don't squash) Added support for enabling LinearWrapper for SDLoRA

added samples_per_batch and fsdp arguments to SDXL inference

added an extra optional wrapper to FSDP
rohitrango committed Jun 25, 2024
1 parent 67011e7 commit 90ad679
Showing 22 changed files with 814 additions and 53 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -3,6 +3,7 @@
*.pkl
#*.ipynb
output
output_2048
result
*.pt
tests/data/asr
@@ -179,3 +180,4 @@ examples/neural_graphs/*.yml
.hydra/
nemo_experiments/

slurm*.out
@@ -17,7 +17,6 @@ trainer:
enable_model_summary: True
limit_val_batches: 0


exp_manager:
exp_dir: null
name: ${name}
@@ -58,8 +58,6 @@ model:
lossconfig:
target: torch.nn.Identity



conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
@@ -125,7 +125,6 @@ model:
target: torch.nn.Identity



conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
@@ -31,9 +31,9 @@ infer:
sampling:
base:
sampler: EulerEDMSampler
width: 256
height: 256
steps: 40
width: 512
height: 512
steps: 50
discretization: "LegacyDDPMDiscretization"
guider: "VanillaCFG"
thresholder: "None"
@@ -48,8 +48,8 @@ sampling:
s_noise: 1.0
eta: 1.0
order: 4
orig_width: 1024
orig_height: 1024
orig_width: 512
orig_height: 512
crop_coords_top: 0
crop_coords_left: 0
aesthetic_score: 5.0
@@ -0,0 +1,189 @@
trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 32
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models
gradient_clip_val: 1.0
benchmark: False
enable_model_summary: True
limit_val_batches: 0


infer:
num_samples_per_batch: 1
num_samples: 4
prompt:
- "A professional photograph of an astronaut riding a pig"
- 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.'
- 'A cute corgi lives in a house made out of sushi.'
- 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.'
- 'A brain riding a rocketship heading towards the moon.'
negative_prompt: ""
seed: 123


sampling:
base:
sampler: EulerEDMSampler
width: 512
height: 512
steps: 50
discretization: "LegacyDDPMDiscretization"
guider: "VanillaCFG"
thresholder: "None"
scale: 5.0
img2img_strength: 1.0
sigma_min: 0.0292
sigma_max: 14.6146
rho: 3.0
s_churn: 0.0
s_tmin: 0.0
s_tmax: 999.0
s_noise: 1.0
eta: 1.0
order: 4
orig_width: 512
orig_height: 512
crop_coords_top: 0
crop_coords_left: 0
aesthetic_score: 5.0
negative_aesthetic_score: 5.0

# model:
# is_legacy: False

use_refiner: False
use_fp16: False # use fp16 model weights
out_path: ./output

base_model_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_base.yaml
refiner_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_refiner.yaml

model:
scale_factor: 0.13025
disable_first_stage_autocast: True
is_legacy: False
restore_from_path: ""

fsdp: False
fsdp_set_buffer_dtype: null
fsdp_sharding_strategy: 'full'
use_cpu_initialization: True
# hidden_size: 4
# pipeline_model_parallel_size: 4

optim:
name: fused_adam
lr: 1e-4
weight_decay: 0.0
betas:
- 0.9
- 0.999
sched:
name: WarmupHoldPolicy
warmup_steps: 10
hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant

denoiser_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser
num_idx: 1000

weighting_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting
scaling_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling
discretization_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization

unet_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel
from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt
from_NeMo: True
adm_in_channels: 2816
num_classes: sequential
use_checkpoint: False
in_channels: 4
out_channels: 4
model_channels: 320
attention_resolutions: [ 4, 2 ]
num_res_blocks: 2
channel_mult: [ 1, 2, 4 ]
num_head_channels: 64
use_spatial_transformer: True
use_linear_in_transformer: True
transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16
context_dim: 2048
image_size: 64 # unused
# spatial_transformer_attn_type: softmax #note: only default softmax is supported now
legacy: False
use_flash_attention: False

first_stage_config:
# _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
_target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper
from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt
from_NeMo: True
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla
double_z: true
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
lossconfig:
target: torch.nn.Identity

conditioner_config:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner
emb_models:
# crossattn cond
- is_trainable: False
input_key: txt
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder
layer: hidden
layer_idx: 11
# crossattn and vector cond
- is_trainable: False
input_key: txt
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2
arch: ViT-bigG-14
version: laion2b_s39b_b160k
freeze: True
layer: penultimate
always_return_pooled: True
legacy: False
# vector cond
- is_trainable: False
input_key: original_size_as_tuple
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: crop_coords_top_left
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
# vector cond
- is_trainable: False
input_key: target_size_as_tuple
emb_model:
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND
outdim: 256 # multiplied by two
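
The infer block above drives the batched sampling loop added to examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py further below. A minimal sketch of the batching arithmetic only — it assumes this new yaml is saved under some local name (the actual filename is not visible in this view) and is inspected with OmegaConf directly rather than through the script's Hydra entry point:

from omegaconf import OmegaConf

cfg = OmegaConf.load("sd_xl_infer_new.yaml")  # hypothetical filename
# Fall back to num_samples when num_samples_per_batch is absent, as the script does.
per_batch = cfg.infer.get("num_samples_per_batch", cfg.infer.num_samples)
num_batches = cfg.infer.num_samples // per_batch
print(len(cfg.infer.prompt), num_batches, per_batch)  # 5 prompts, 4 batches of 1 sample each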

@@ -93,9 +93,7 @@ def main(cfg) -> None:
model.zero_grad()

if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path is not None:
# initialize peft weights from a checkpoint instead of randomly
# This is not the same as resume training because optimizer states are not restored.
44 changes: 28 additions & 16 deletions examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py
@@ -26,32 +26,44 @@ def model_cfg_modifier(model_cfg):
model_cfg.precision = cfg.trainer.precision
model_cfg.ckpt_path = None
model_cfg.inductor = False
model_cfg.unet_config.from_pretrained = None
model_cfg.first_stage_config.from_pretrained = None
model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt"
model_cfg.unet_config.from_NeMo = True
model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt"
model_cfg.first_stage_config.from_NeMo = True
model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper'
model_cfg.fsdp = False
# model_cfg.fsdp = True

torch.backends.cuda.matmul.allow_tf32 = True
trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference(
model_provider=MegatronDiffusionEngine, cfg=cfg, model_cfg_modifier=model_cfg_modifier
)

### Manually configure sharded model
# model = megatron_diffusion_model
# model = trainer.strategy._setup_model(model)
# model = model.cuda(torch.cuda.current_device())
# get the diffusion part only
model = megatron_diffusion_model.model
model.cuda().eval()

base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy)
use_refiner = cfg.get('use_refiner', False)
for i, prompt in enumerate(cfg.infer.prompt):
samples = base.text_to_image(
params=cfg.sampling.base,
prompt=[prompt],
negative_prompt=cfg.infer.negative_prompt,
samples=cfg.infer.num_samples,
return_latents=True if use_refiner else False,
seed=int(cfg.infer.seed + i * 100),
)

perform_save_locally(cfg.out_path, samples)
with torch.no_grad():
base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy)
use_refiner = cfg.get('use_refiner', False)
num_samples_per_batch = cfg.infer.get('num_samples_per_batch', cfg.infer.num_samples)
num_batches = cfg.infer.num_samples // num_samples_per_batch

for i, prompt in enumerate(cfg.infer.prompt):
for batchid in range(num_batches):
samples = base.text_to_image(
params=cfg.sampling.base,
prompt=[prompt],
negative_prompt=cfg.infer.negative_prompt,
samples=num_samples_per_batch,
return_latents=True if use_refiner else False,
seed=int(cfg.infer.seed + i * 100 + batchid * 200),
)
# samples=cfg.infer.num_samples,
perform_save_locally(cfg.out_path, samples)


if __name__ == "__main__":
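For reference, the seed schedule the loop above produces with the defaults from the new inference yaml (seed=123, 5 prompts, num_samples=4, num_samples_per_batch=1) — purely an illustration of the seed + i * 100 + batchid * 200 arithmetic, not code from the commit:

seed, num_prompts, num_batches = 123, 5, 4
for i in range(num_prompts):
    print(i, [int(seed + i * 100 + b * 200) for b in range(num_batches)])
# 0 [123, 323, 523, 723]
# 1 [223, 423, 623, 823]
# ...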
@@ -81,9 +81,7 @@ def main(cfg) -> None:
model = MegatronDiffusionEngine(cfg.model, trainer)

if cfg.model.get('peft', None):

peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme]

if cfg.model.peft.restore_from_path is not None:
# initialize peft weights from a checkpoint instead of randomly
# This is not the same as resume training because optimizer states are not restored.
@@ -158,6 +158,13 @@ def decode_first_stage(self, z):
out = self.first_stage_model.decode(z)
return out

# same as above but differentiable
def differentiable_decode_first_stage(self, z):
z = 1.0 / self.scale_factor * z
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
out = self.first_stage_model.decode(z)
return out

@torch.no_grad()
def encode_first_stage(self, x):
with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast):
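The comment "same as above but differentiable" suggests that decode_first_stage itself runs under torch.no_grad(); the new method keeps the autograd graph so gradients can flow back through the VAE decode, which is what reward-based fine-tuning (e.g. the draft+ use case mentioned in the commit message) needs. A minimal usage sketch — engine and reward_model are placeholder names, not part of this diff:

import torch

# engine: the patched DiffusionEngine; reward_model: any differentiable image scorer.
z = torch.randn(1, 4, 64, 64, device="cuda", requires_grad=True)   # latent sample
images = engine.differentiable_decode_first_stage(z)               # autograd graph kept
loss = -reward_model(images).mean()                                # maximize reward
loss.backward()                                                     # gradients reach z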
@@ -16,6 +16,7 @@
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from nemo.utils import logging

try:
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
@@ -316,6 +317,7 @@ def __init__(
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
from_NeMo=False,
monitor=None,
from_pretrained: str = None,
):
@@ -337,6 +339,7 @@
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

if from_pretrained is not None:
logging.info(f"Attempting to load vae weights from {from_pretrained}")
if from_pretrained.endswith('safetensors'):
from safetensors.torch import load_file as load_safetensors

@@ -345,7 +348,7 @@ def __init__(
state_dict = torch.load(from_pretrained)
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict)
missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo)
if len(missing_key) > 0:
print(
f'{self.__class__.__name__}: Following keys are missing during loading VAE weights, which may lead to compromised image quality for a resumed training. Please check the checkpoint you provided.'
@@ -355,7 +358,6 @@ def __init__(

def _state_key_mapping(self, state_dict: dict):
import re

res_dict = {}
key_list = state_dict.keys()
key_str = " ".join(key_list)
@@ -395,8 +397,9 @@ def _state_key_mapping(self, state_dict: dict):
res_dict[key_] = val_
return res_dict

def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False):
state_dict = self._state_key_mapping(state_dict)
def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo = False):
if not from_NeMo:
state_dict = self._state_key_mapping(state_dict)
model_state_dict = self.state_dict()
loaded_keys = [k for k in state_dict.keys()]
expected_keys = list(model_state_dict.keys())
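Taken together, the new from_NeMo flag only controls whether the HF-to-NeMo key remapping runs before the VAE weights are loaded. A rough sketch of the resulting call, assuming vae is an AutoencoderKLInferenceWrapper and path points at a checkpoint (both names chosen for illustration):

import torch

state_dict = torch.load(path, map_location="cpu")
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]

# from_NeMo=True: keys already use NeMo module names, so _state_key_mapping is skipped.
# from_NeMo=False: keys are assumed to come from an HF checkpoint and are remapped first.
missing, unexpected, _, _ = vae._load_pretrained_model(state_dict, from_NeMo=True)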
(diffs for the remaining changed files are not shown in this view)
