From 46940cb64b50072da629b66a13707242373bd99e Mon Sep 17 00:00:00 2001
From: Eren Gölge
Date: Mon, 6 Nov 2023 19:01:02 +0100
Subject: [PATCH 1/4] Drop diffusion for XTTS

---
 TTS/tts/layers/tortoise/dpm_solver.py        | 1551 ------------------
 TTS/tts/layers/xtts/diffusion.py             | 1319 ---------------
 TTS/tts/models/xtts.py                       |   17 -
 recipes/ljspeech/xtts_v1/train_gpt_xtts.py   |    2 +-
 recipes/ljspeech/xtts_v2/train_gpt_xtts.py   |    2 +-
 tests/xtts_tests/test_xtts_gpt_train.py      |    2 +-
 tests/xtts_tests/test_xtts_v2-0_gpt_train.py |    2 +-
 7 files changed, 4 insertions(+), 2891 deletions(-)
 delete mode 100644 TTS/tts/layers/tortoise/dpm_solver.py
 delete mode 100644 TTS/tts/layers/xtts/diffusion.py

diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py
deleted file mode 100644
index cb540577f8..0000000000
--- a/TTS/tts/layers/tortoise/dpm_solver.py
+++ /dev/null
@@ -1,1551 +0,0 @@
-import math
-
-import torch
-
-
-class NoiseScheduleVP:
-    def __init__(
-        self,
-        schedule="discrete",
-        betas=None,
-        alphas_cumprod=None,
-        continuous_beta_0=0.1,
-        continuous_beta_1=20.0,
-        dtype=torch.float32,
-    ):
-        """Create a wrapper class for the forward SDE (VP type).
-
-        ***
-        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
-        We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
-        ***
-
-        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
-        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
-        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
-
-            log_alpha_t = self.marginal_log_mean_coeff(t)
-            sigma_t = self.marginal_std(t)
-            lambda_t = self.marginal_lambda(t)
-
-        Moreover, as lambda(t) is an invertible function, we also support its inverse function:
-
-            t = self.inverse_lambda(lambda_t)
-
-        ===============================================================
-
-        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
-
-        1. For discrete-time DPMs:
-
-            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
-                t_i = (i + 1) / N
-            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
-            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
-
-            Args:
-                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
-                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
-
-            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
-
-            **Important**: Please pay special attention to the argument `alphas_cumprod`:
-                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
-                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
-                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
-                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
-                and
-                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
-
-        2. For continuous-time DPMs:
-
-            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).
The hyperparameters for the noise - schedule are the default settings in DDPM and improved-DDPM: - - Args: - beta_min: A `float` number. The smallest beta for the linear schedule. - beta_max: A `float` number. The largest beta for the linear schedule. - cosine_s: A `float` number. The hyperparameter in the cosine schedule. - cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. - T: A `float` number. The ending time of the forward process. - - =============================================================== - - Args: - schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, - 'linear' or 'cosine' for continuous-time DPMs. - Returns: - A wrapper object of the forward SDE (VP type). - - =============================================================== - - Example: - - # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', betas=betas) - - # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): - >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) - - # For continuous-time DPMs (VPSDE), linear schedule: - >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) - - """ - - if schedule not in ["discrete", "linear", "cosine"]: - raise ValueError( - "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( - schedule - ) - ) - - self.schedule = schedule - if schedule == "discrete": - if betas is not None: - log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) - else: - assert alphas_cumprod is not None - log_alphas = 0.5 * torch.log(alphas_cumprod) - self.total_N = len(log_alphas) - self.T = 1.0 - self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) - self.log_alpha_array = log_alphas.reshape( - ( - 1, - -1, - ) - ).to(dtype=dtype) - else: - self.total_N = 1000 - self.beta_0 = continuous_beta_0 - self.beta_1 = continuous_beta_1 - self.cosine_s = 0.008 - self.cosine_beta_max = 999.0 - self.cosine_t_max = ( - math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) - * 2.0 - * (1.0 + self.cosine_s) - / math.pi - - self.cosine_s - ) - self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) - self.schedule = schedule - if schedule == "cosine": - # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. - # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. - self.T = 0.9946 - else: - self.T = 1.0 - - def marginal_log_mean_coeff(self, t): - """ - Compute log(alpha_t) of a given continuous-time label t in [0, T]. - """ - if self.schedule == "discrete": - return interpolate_fn( - t.reshape((-1, 1)), - self.t_array.to(t.device), - self.log_alpha_array.to(t.device), - ).reshape((-1)) - elif self.schedule == "linear": - return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 - elif self.schedule == "cosine": - - def log_alpha_fn(s): - return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) - - log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 - return log_alpha_t - - def marginal_alpha(self, t): - """ - Compute alpha_t of a given continuous-time label t in [0, T]. - """ - return torch.exp(self.marginal_log_mean_coeff(t)) - - def marginal_std(self, t): - """ - Compute sigma_t of a given continuous-time label t in [0, T]. 
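# A minimal standalone sketch of the marginal quantities implemented above,
# written out for the linear schedule with the default beta_0=0.1, beta_1=20.0
# (illustration only; not part of the module):
import torch

def log_alpha_linear(t, beta_0=0.1, beta_1=20.0):
    # log(alpha_t) = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
    return -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0

t = torch.tensor([0.5])
log_alpha_t = log_alpha_linear(t)
sigma_t = torch.sqrt(1.0 - torch.exp(2.0 * log_alpha_t))  # marginal_std
lambda_t = log_alpha_t - torch.log(sigma_t)               # half-logSNR, marginal_lambda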
- """ - return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) - - def marginal_lambda(self, t): - """ - Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. - """ - log_mean_coeff = self.marginal_log_mean_coeff(t) - log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) - return log_mean_coeff - log_std - - def inverse_lambda(self, lamb): - """ - Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. - """ - if self.schedule == "linear": - tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) - Delta = self.beta_0**2 + tmp - return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) - elif self.schedule == "discrete": - log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) - t = interpolate_fn( - log_alpha.reshape((-1, 1)), - torch.flip(self.log_alpha_array.to(lamb.device), [1]), - torch.flip(self.t_array.to(lamb.device), [1]), - ) - return t.reshape((-1,)) - else: - log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) - - def t_fn(log_alpha_t): - return ( - torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) - * 2.0 - * (1.0 + self.cosine_s) - / math.pi - - self.cosine_s - ) - - t = t_fn(log_alpha) - return t - - -def model_wrapper( - model, - noise_schedule, - model_type="noise", - model_kwargs={}, - guidance_type="uncond", - condition=None, - unconditional_condition=None, - guidance_scale=1.0, - classifier_fn=None, - classifier_kwargs={}, -): - """Create a wrapper function for the noise prediction model. - - DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to - firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. - - We support four types of the diffusion model by setting `model_type`: - - 1. "noise": noise prediction model. (Trained by predicting noise). - - 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). - - 3. "v": velocity prediction model. (Trained by predicting the velocity). - The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. - - [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." - arXiv preprint arXiv:2202.00512 (2022). - [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." - arXiv preprint arXiv:2210.02303 (2022). - - 4. "score": marginal score function. (Trained by denoising score matching). - Note that the score function and the noise prediction model follows a simple relationship: - ``` - noise(x_t, t) = -sigma_t * score(x_t, t) - ``` - - We support three types of guided sampling by DPMs by setting `guidance_type`: - 1. "uncond": unconditional sampling by DPMs. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. - The input `model` has the following format: - `` - model(x, t_input, **model_kwargs) -> noise | x_start | v | score - `` - - The input `classifier_fn` has the following format: - `` - classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) - `` - - [3] P. Dhariwal and A. Q. 
Nichol, "Diffusion models beat GANs on image synthesis," - in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. - - 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. - The input `model` has the following format: - `` - model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score - `` - And if cond == `unconditional_condition`, the model output is the unconditional DPM output. - - [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." - arXiv preprint arXiv:2207.12598 (2022). - - - The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) - or continuous-time labels (i.e. epsilon to T). - - We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: - `` - def model_fn(x, t_continuous) -> noise: - t_input = get_model_input_time(t_continuous) - return noise_pred(model, x, t_input, **model_kwargs) - `` - where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. - - =============================================================== - - Args: - model: A diffusion model with the corresponding format described above. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - model_type: A `str`. The parameterization type of the diffusion model. - "noise" or "x_start" or "v" or "score". - model_kwargs: A `dict`. A dict for the other inputs of the model function. - guidance_type: A `str`. The type of the guidance for sampling. - "uncond" or "classifier" or "classifier-free". - condition: A pytorch tensor. The condition for the guided sampling. - Only used for "classifier" or "classifier-free" guidance type. - unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. - Only used for "classifier-free" guidance type. - guidance_scale: A `float`. The scale for the guided sampling. - classifier_fn: A classifier function. Only used for the classifier guidance. - classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. - Returns: - A noise prediction model that accepts the noised data and the continuous time as the inputs. - """ - - def get_model_input_time(t_continuous): - """ - Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. - For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. - For continuous-time DPMs, we just use `t_continuous`. - """ - if noise_schedule.schedule == "discrete": - return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 - else: - return t_continuous - - def noise_pred_fn(x, t_continuous, cond=None): - t_input = get_model_input_time(t_continuous) - if cond is None: - output = model(x, t_input, **model_kwargs) - else: - output = model(x, t_input, cond, **model_kwargs) - if model_type == "noise": - return output - elif model_type == "x_start": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - return (x - alpha_t * output) / sigma_t - elif model_type == "v": - alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) - return alpha_t * output + sigma_t * x - elif model_type == "score": - sigma_t = noise_schedule.marginal_std(t_continuous) - return -sigma_t * output - - def cond_grad_fn(x, t_input): - """ - Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). 
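# A toy sketch of the classifier gradient that cond_grad_fn computes; the
# quadratic `toy_classifier_fn` is a made-up stand-in for a real classifier
# log-probability (illustration only):
import torch

def toy_classifier_fn(x, t_input, cond):
    return -((x - cond) ** 2).sum(dim=-1)  # stand-in for log p_t(cond | x_t)

x = torch.randn(4, 8)
with torch.enable_grad():
    x_in = x.detach().requires_grad_(True)
    log_prob = toy_classifier_fn(x_in, torch.tensor(0.5), torch.ones(8))
    cond_grad = torch.autograd.grad(log_prob.sum(), x_in)[0]  # nabla_x log p_t(cond | x_t)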
- """ - with torch.enable_grad(): - x_in = x.detach().requires_grad_(True) - log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) - return torch.autograd.grad(log_prob.sum(), x_in)[0] - - def model_fn(x, t_continuous): - """ - The noise predicition model function that is used for DPM-Solver. - """ - if guidance_type == "uncond": - return noise_pred_fn(x, t_continuous) - elif guidance_type == "classifier": - assert classifier_fn is not None - t_input = get_model_input_time(t_continuous) - cond_grad = cond_grad_fn(x, t_input) - sigma_t = noise_schedule.marginal_std(t_continuous) - noise = noise_pred_fn(x, t_continuous) - return noise - guidance_scale * sigma_t * cond_grad - elif guidance_type == "classifier-free": - if guidance_scale == 1.0 or unconditional_condition is None: - return noise_pred_fn(x, t_continuous, cond=condition) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - return noise_uncond + guidance_scale * (noise - noise_uncond) - - assert model_type in ["noise", "x_start", "v", "score"] - assert guidance_type in ["uncond", "classifier", "classifier-free"] - return model_fn - - -class DPM_Solver: - def __init__( - self, - model_fn, - noise_schedule, - algorithm_type="dpmsolver++", - correcting_x0_fn=None, - correcting_xt_fn=None, - thresholding_max_val=1.0, - dynamic_thresholding_ratio=0.995, - ): - """Construct a DPM-Solver. - - We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). - - We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you - can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the - dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space - DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space - DPMs (such as stable-diffusion). - - To support advanced algorithms in image-to-image applications, we also support corrector functions for - both x0 and xt. - - Args: - model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): - `` - def model_fn(x, t_continuous): - return noise - `` - The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. - noise_schedule: A noise schedule object, such as NoiseScheduleVP. - algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". - correcting_x0_fn: A `str` or a function with the following format: - ``` - def correcting_x0_fn(x0, t): - x0_new = ... - return x0_new - ``` - This function is to correct the outputs of the data prediction model at each sampling step. e.g., - ``` - x0_pred = data_pred_model(xt, t) - if correcting_x0_fn is not None: - x0_pred = correcting_x0_fn(x0_pred, t) - xt_1 = update(x0_pred, xt, t) - ``` - If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. - correcting_xt_fn: A function with the following format: - ``` - def correcting_xt_fn(xt, t, step): - x_new = ... - return x_new - ``` - This function is to correct the intermediate samples xt at each sampling step. e.g., - ``` - xt = ... - xt = correcting_xt_fn(xt, t, step) - ``` - thresholding_max_val: A `float`. The max value for thresholding. 
- Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. - dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). - Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. - - [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, - Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models - with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. - """ - self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) - self.noise_schedule = noise_schedule - assert algorithm_type in ["dpmsolver", "dpmsolver++"] - self.algorithm_type = algorithm_type - if correcting_x0_fn == "dynamic_thresholding": - self.correcting_x0_fn = self.dynamic_thresholding_fn - else: - self.correcting_x0_fn = correcting_x0_fn - self.correcting_xt_fn = correcting_xt_fn - self.dynamic_thresholding_ratio = dynamic_thresholding_ratio - self.thresholding_max_val = thresholding_max_val - - def dynamic_thresholding_fn(self, x0, t): - """ - The dynamic thresholding method. - """ - dims = x0.dim() - p = self.dynamic_thresholding_ratio - s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims( - torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), - dims, - ) - x0 = torch.clamp(x0, -s, s) / s - return x0 - - def noise_prediction_fn(self, x, t): - """ - Return the noise prediction model. - """ - return self.model(x, t) - - def data_prediction_fn(self, x, t): - """ - Return the data prediction model (with corrector). - """ - noise = self.noise_prediction_fn(x, t) - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - x0 = (x - sigma_t * noise) / alpha_t - if self.correcting_x0_fn is not None: - x0 = self.correcting_x0_fn(x0, t) - return x0 - - def model_fn(self, x, t): - """ - Convert the model to the noise prediction model or the data prediction model. - """ - if self.algorithm_type == "dpmsolver++": - return self.data_prediction_fn(x, t) - else: - return self.noise_prediction_fn(x, t) - - def get_time_steps(self, skip_type, t_T, t_0, N, device): - """Compute the intermediate time steps for sampling. - - Args: - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - N: A `int`. The total number of the spacing of the time steps. - device: A torch device. - Returns: - A pytorch tensor of the time steps, with the shape (N + 1,). 
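# A quick sketch of the two time-based skip_type spacings described above, for
# t_T=1.0, t_0=1e-3, N=10; 'logSNR' additionally needs the schedule's
# marginal_lambda / inverse_lambda (illustration only):
import torch

t_T, t_0, N = 1.0, 1e-3, 10
time_uniform = torch.linspace(t_T, t_0, N + 1)
time_quadratic = torch.linspace(t_T**0.5, t_0**0.5, N + 1).pow(2)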
- """ - if skip_type == "logSNR": - lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) - lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) - logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) - return self.noise_schedule.inverse_lambda(logSNR_steps) - elif skip_type == "time_uniform": - return torch.linspace(t_T, t_0, N + 1).to(device) - elif skip_type == "time_quadratic": - t_order = 2 - t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) - return t - else: - raise ValueError( - "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) - ) - - def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): - """ - Get the order of each step for sampling by the singlestep DPM-Solver. - - We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". - Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: - - If order == 1: - We take `steps` of DPM-Solver-1 (i.e. DDIM). - - If order == 2: - - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of DPM-Solver-2. - - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If order == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. - - ============================================ - Args: - order: A `int`. The max order for the solver (2 or 3). - steps: A `int`. The total number of function evaluations (NFE). - skip_type: A `str`. The type for the spacing of the time steps. We support three types: - - 'logSNR': uniform logSNR for the time steps. - - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) - - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - device: A torch device. - Returns: - orders: A list of the solver order of each step. 
- """ - if order == 3: - K = steps // 3 + 1 - if steps % 3 == 0: - orders = [3,] * ( - K - 2 - ) + [2, 1] - elif steps % 3 == 1: - orders = [3,] * ( - K - 1 - ) + [1] - else: - orders = [3,] * ( - K - 1 - ) + [2] - elif order == 2: - if steps % 2 == 0: - K = steps // 2 - orders = [ - 2, - ] * K - else: - K = steps // 2 + 1 - orders = [2,] * ( - K - 1 - ) + [1] - elif order == 1: - K = 1 - orders = [ - 1, - ] * steps - else: - raise ValueError("'order' must be '1' or '2' or '3'.") - if skip_type == "logSNR": - # To reproduce the results in DPM-Solver paper - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) - else: - timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ - torch.cumsum( - torch.tensor( - [ - 0, - ] - + orders - ), - 0, - ).to(device) - ] - return timesteps_outer, orders - - def denoise_to_zero_fn(self, x, s): - """ - Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. - """ - return self.data_prediction_fn(x, s) - - def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): - """ - DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - ns = self.noise_schedule - dims = x.dim() - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) - sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s - if return_intermediate: - return x_t, {"model_s": model_s} - else: - return x_t - else: - phi_1 = torch.expm1(h) - if model_s is None: - model_s = self.model_fn(x, s) - x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s - if return_intermediate: - return x_t, {"model_s": model_s} - else: - return x_t - - def singlestep_dpm_solver_second_update( - self, - x, - s, - t, - r1=0.5, - model_s=None, - return_intermediate=False, - solver_type="dpmsolver", - ): - """ - Singlestep solver DPM-Solver-2 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - r1: A `float`. The hyperparameter of the second-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
- """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 0.5 - ns = self.noise_schedule - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - s1 = ns.inverse_lambda(lambda_s1) - log_alpha_s, log_alpha_s1, log_alpha_t = ( - ns.marginal_log_mean_coeff(s), - ns.marginal_log_mean_coeff(s1), - ns.marginal_log_mean_coeff(t), - ) - sigma_s, sigma_s1, sigma_t = ( - ns.marginal_std(s), - ns.marginal_std(s1), - ns.marginal_std(t), - ) - alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) - - if self.algorithm_type == "dpmsolver++": - phi_11 = torch.expm1(-r1 * h) - phi_1 = torch.expm1(-h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - if solver_type == "dpmsolver": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) - ) - elif solver_type == "taylor": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_1 = torch.expm1(h) - - if model_s is None: - model_s = self.model_fn(x, s) - x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - if solver_type == "dpmsolver": - x_t = ( - torch.exp(log_alpha_t - log_alpha_s) * x - - (sigma_t * phi_1) * model_s - - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) - ) - elif solver_type == "taylor": - x_t = ( - torch.exp(log_alpha_t - log_alpha_s) * x - - (sigma_t * phi_1) * model_s - - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) - ) - if return_intermediate: - return x_t, {"model_s": model_s, "model_s1": model_s1} - else: - return x_t - - def singlestep_dpm_solver_third_update( - self, - x, - s, - t, - r1=1.0 / 3.0, - r2=2.0 / 3.0, - model_s=None, - model_s1=None, - return_intermediate=False, - solver_type="dpmsolver", - ): - """ - Singlestep solver DPM-Solver-3 from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - r1: A `float`. The hyperparameter of the third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - model_s: A pytorch tensor. The model function evaluated at time `s`. - If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. - model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). - If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
- """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - if r1 is None: - r1 = 1.0 / 3.0 - if r2 is None: - r2 = 2.0 / 3.0 - ns = self.noise_schedule - lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) - h = lambda_t - lambda_s - lambda_s1 = lambda_s + r1 * h - lambda_s2 = lambda_s + r2 * h - s1 = ns.inverse_lambda(lambda_s1) - s2 = ns.inverse_lambda(lambda_s2) - log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( - ns.marginal_log_mean_coeff(s), - ns.marginal_log_mean_coeff(s1), - ns.marginal_log_mean_coeff(s2), - ns.marginal_log_mean_coeff(t), - ) - sigma_s, sigma_s1, sigma_s2, sigma_t = ( - ns.marginal_std(s), - ns.marginal_std(s1), - ns.marginal_std(s2), - ns.marginal_std(t), - ) - alpha_s1, alpha_s2, alpha_t = ( - torch.exp(log_alpha_s1), - torch.exp(log_alpha_s2), - torch.exp(log_alpha_t), - ) - - if self.algorithm_type == "dpmsolver++": - phi_11 = torch.expm1(-r1 * h) - phi_12 = torch.expm1(-r2 * h) - phi_1 = torch.expm1(-h) - phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 - phi_2 = phi_1 / h + 1.0 - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - (sigma_s2 / sigma_s) * x - - (alpha_s2 * phi_12) * model_s - + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == "dpmsolver": - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) - ) - elif solver_type == "taylor": - D1_0 = (1.0 / r1) * (model_s1 - model_s) - D1_1 = (1.0 / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - (sigma_t / sigma_s) * x - - (alpha_t * phi_1) * model_s - + (alpha_t * phi_2) * D1 - - (alpha_t * phi_3) * D2 - ) - else: - phi_11 = torch.expm1(r1 * h) - phi_12 = torch.expm1(r2 * h) - phi_1 = torch.expm1(h) - phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 - phi_2 = phi_1 / h - 1.0 - phi_3 = phi_2 / h - 0.5 - - if model_s is None: - model_s = self.model_fn(x, s) - if model_s1 is None: - x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s - model_s1 = self.model_fn(x_s1, s1) - x_s2 = ( - (torch.exp(log_alpha_s2 - log_alpha_s)) * x - - (sigma_s2 * phi_12) * model_s - - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) - ) - model_s2 = self.model_fn(x_s2, s2) - if solver_type == "dpmsolver": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_s)) * x - - (sigma_t * phi_1) * model_s - - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) - ) - elif solver_type == "taylor": - D1_0 = (1.0 / r1) * (model_s1 - model_s) - D1_1 = (1.0 / r2) * (model_s2 - model_s) - D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) - D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) - x_t = ( - (torch.exp(log_alpha_t - log_alpha_s)) * x - - (sigma_t * phi_1) * model_s - - (sigma_t * phi_2) * D1 - - (sigma_t * phi_3) * D2 - ) - - if return_intermediate: - return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} - else: - return x_t - - def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): - """ - Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. 
The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if solver_type not in ["dpmsolver", "taylor"]: - raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) - ns = self.noise_schedule - model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] - t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] - lambda_prev_1, lambda_prev_0, lambda_t = ( - ns.marginal_lambda(t_prev_1), - ns.marginal_lambda(t_prev_0), - ns.marginal_lambda(t), - ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0 = h_0 / h - D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - if solver_type == "dpmsolver": - x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 - elif solver_type == "taylor": - x_t = ( - (sigma_t / sigma_prev_0) * x - - (alpha_t * phi_1) * model_prev_0 - + (alpha_t * (phi_1 / h + 1.0)) * D1_0 - ) - else: - phi_1 = torch.expm1(h) - if solver_type == "dpmsolver": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - 0.5 * (sigma_t * phi_1) * D1_0 - ) - elif solver_type == "taylor": - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - (sigma_t * (phi_1 / h - 1.0)) * D1_0 - ) - return x_t - - def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): - """ - Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. 
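# A toy sketch of the divided differences the multistep updates build from
# cached model outputs; D1 and D2 approximate the first and second derivatives
# of the model output with respect to lambda (illustration only; toy values):
import torch

m0, m1, m2 = torch.randn(3, 4)   # model_prev_0 / _1 / _2
h, h_0, h_1 = 0.5, 0.4, 0.3      # current and previous logSNR gaps
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1.0 / r0) * (m0 - m1)
D1_1 = (1.0 / r1) * (m1 - m2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)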
- """ - ns = self.noise_schedule - model_prev_2, model_prev_1, model_prev_0 = model_prev_list - t_prev_2, t_prev_1, t_prev_0 = t_prev_list - lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( - ns.marginal_lambda(t_prev_2), - ns.marginal_lambda(t_prev_1), - ns.marginal_lambda(t_prev_0), - ns.marginal_lambda(t), - ) - log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) - sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) - alpha_t = torch.exp(log_alpha_t) - - h_1 = lambda_prev_1 - lambda_prev_2 - h_0 = lambda_prev_0 - lambda_prev_1 - h = lambda_t - lambda_prev_0 - r0, r1 = h_0 / h, h_1 / h - D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) - D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.algorithm_type == "dpmsolver++": - phi_1 = torch.expm1(-h) - phi_2 = phi_1 / h + 1.0 - phi_3 = phi_2 / h - 0.5 - x_t = ( - (sigma_t / sigma_prev_0) * x - - (alpha_t * phi_1) * model_prev_0 - + (alpha_t * phi_2) * D1 - - (alpha_t * phi_3) * D2 - ) - else: - phi_1 = torch.expm1(h) - phi_2 = phi_1 / h - 1.0 - phi_3 = phi_2 / h - 0.5 - x_t = ( - (torch.exp(log_alpha_t - log_alpha_prev_0)) * x - - (sigma_t * phi_1) * model_prev_0 - - (sigma_t * phi_2) * D1 - - (sigma_t * phi_3) * D2 - ) - return x_t - - def singlestep_dpm_solver_update( - self, - x, - s, - t, - order, - return_intermediate=False, - solver_type="dpmsolver", - r1=None, - r2=None, - ): - """ - Singlestep DPM-Solver with the order `order` from time `s` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - s: A pytorch tensor. The starting time, with the shape (1,). - t: A pytorch tensor. The ending time, with the shape (1,). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - r1: A `float`. The hyperparameter of the second-order or third-order solver. - r2: A `float`. The hyperparameter of the third-order solver. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) - elif order == 2: - return self.singlestep_dpm_solver_second_update( - x, - s, - t, - return_intermediate=return_intermediate, - solver_type=solver_type, - r1=r1, - ) - elif order == 3: - return self.singlestep_dpm_solver_third_update( - x, - s, - t, - return_intermediate=return_intermediate, - solver_type=solver_type, - r1=r1, - r2=r2, - ) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): - """ - Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. - - Args: - x: A pytorch tensor. The initial value at time `s`. - model_prev_list: A list of pytorch tensor. The previous computed model values. - t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) - t: A pytorch tensor. The ending time, with the shape (1,). - order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. - solver_type: either 'dpmsolver' or 'taylor'. 
The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_t: A pytorch tensor. The approximated solution at time `t`. - """ - if order == 1: - return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) - elif order == 2: - return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - elif order == 3: - return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) - else: - raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) - - def dpm_solver_adaptive( - self, - x, - order, - t_T, - t_0, - h_init=0.05, - atol=0.0078, - rtol=0.05, - theta=0.9, - t_err=1e-5, - solver_type="dpmsolver", - ): - """ - The adaptive step size solver based on singlestep DPM-Solver. - - Args: - x: A pytorch tensor. The initial value at time `t_T`. - order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. - t_T: A `float`. The starting time of the sampling (default is T). - t_0: A `float`. The ending time of the sampling (default is epsilon). - h_init: A `float`. The initial step size (for logSNR). - atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. - rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. - theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. - t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the - current time and `t_0` is less than `t_err`. The default setting is 1e-5. - solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. - The type slightly impacts the performance. We recommend to use 'dpmsolver' type. - Returns: - x_0: A pytorch tensor. The approximated solution at time `t_0`. - - [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. 
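# A standalone sketch of the accept/reject rule used by dpm_solver_adaptive
# below; the real loop additionally caps the next h at lambda_0 - lambda_s
# (illustration only; toy tensors):
import torch

def adapt_step(x_lower, x_higher, x_prev, h, order, atol=0.0078, rtol=0.05, theta=0.9):
    delta = torch.maximum(torch.full_like(x_lower, atol),
                          rtol * torch.maximum(x_lower.abs(), x_prev.abs()))
    err = ((x_higher - x_lower) / delta).reshape(x_lower.shape[0], -1)
    E = err.square().mean(dim=-1).sqrt().max()  # scaled error norm
    return bool(E <= 1.0), theta * h * float(E) ** (-1.0 / order)

accepted, h_next = adapt_step(torch.randn(1, 8), torch.randn(1, 8),
                              torch.randn(1, 8), h=0.05, order=2)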
- """ - ns = self.noise_schedule - s = t_T * torch.ones((1,)).to(x) - lambda_s = ns.marginal_lambda(s) - lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) - h = h_init * torch.ones_like(s).to(x) - x_prev = x - nfe = 0 - if order == 2: - r1 = 0.5 - - def lower_update(x, s, t): - return self.dpm_solver_first_update(x, s, t, return_intermediate=True) - - def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) - - elif order == 3: - r1, r2 = 1.0 / 3.0, 2.0 / 3.0 - - def lower_update(x, s, t): - return self.singlestep_dpm_solver_second_update( - x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type - ) - - def higher_update(x, s, t, **kwargs): - return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) - - else: - raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) - while torch.abs((s - t_0)).mean() > t_err: - t = ns.inverse_lambda(lambda_s + h) - x_lower, lower_noise_kwargs = lower_update(x, s, t) - x_higher = higher_update(x, s, t, **lower_noise_kwargs) - delta = torch.max( - torch.ones_like(x).to(x) * atol, - rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), - ) - - def norm_fn(v): - return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) - - E = norm_fn((x_higher - x_lower) / delta).max() - if torch.all(E <= 1.0): - x = x_higher - s = t - x_prev = x_lower - lambda_s = ns.marginal_lambda(s) - h = torch.min( - theta * h * torch.float_power(E, -1.0 / order).float(), - lambda_0 - lambda_s, - ) - nfe += order - print("adaptive solver nfe", nfe) - return x - - def add_noise(self, x, t, noise=None): - """ - Compute the noised input xt = alpha_t * x + sigma_t * noise. - - Args: - x: A `torch.Tensor` with shape `(batch_size, *shape)`. - t: A `torch.Tensor` with shape `(t_size,)`. - Returns: - xt with shape `(t_size, batch_size, *shape)`. - """ - alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) - if noise is None: - noise = torch.randn((t.shape[0], *x.shape), device=x.device) - x = x.reshape((-1, *x.shape)) - xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise - if t.shape[0] == 1: - return xt.squeeze(0) - else: - return xt - - def inverse( - self, - x, - steps=20, - t_start=None, - t_end=None, - order=2, - skip_type="time_uniform", - method="multistep", - lower_order_final=True, - denoise_to_zero=False, - solver_type="dpmsolver", - atol=0.0078, - rtol=0.05, - return_intermediate=False, - ): - """ - Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. - For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. - """ - t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start - t_T = self.noise_schedule.T if t_end is None else t_end - assert ( - t_0 > 0 and t_T > 0 - ), "Time range needs to be greater than 0. 
For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" - return self.sample( - x, - steps=steps, - t_start=t_0, - t_end=t_T, - order=order, - skip_type=skip_type, - method=method, - lower_order_final=lower_order_final, - denoise_to_zero=denoise_to_zero, - solver_type=solver_type, - atol=atol, - rtol=rtol, - return_intermediate=return_intermediate, - ) - - def sample( - self, - x, - steps=20, - t_start=None, - t_end=None, - order=2, - skip_type="time_uniform", - method="multistep", - lower_order_final=True, - denoise_to_zero=False, - solver_type="dpmsolver", - atol=0.0078, - rtol=0.05, - return_intermediate=False, - ): - """ - Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. - - ===================================================== - - We support the following algorithms for both noise prediction model and data prediction model: - - 'singlestep': - Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. - We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). - The total number of function evaluations (NFE) == `steps`. - Given a fixed NFE == `steps`, the sampling procedure is: - - If `order` == 1: - - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. - - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. - - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If `order` == 3: - - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. - - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. - - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. - - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. - - 'multistep': - Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. - We initialize the first `order` values by lower order multistep solvers. - Given a fixed NFE == `steps`, the sampling procedure is: - Denote K = steps. - - If `order` == 1: - - We use K steps of DPM-Solver-1 (i.e. DDIM). - - If `order` == 2: - - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. - - If `order` == 3: - - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. - - 'singlestep_fixed': - Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). - We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. - - 'adaptive': - Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). - We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. - You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs - (NFE) and the sample quality. - - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. - - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. 
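# A minimal end-to-end sketch of the multistep path described above, assuming
# the NoiseScheduleVP / model_wrapper / DPM_Solver definitions in this module;
# the constant-zero model is a dummy stand-in, not a trained DPM
# (illustration only):
import torch

ns = NoiseScheduleVP("linear", continuous_beta_0=0.1, continuous_beta_1=20.0)
model_fn = model_wrapper(lambda x, t: torch.zeros_like(x), ns, model_type="noise")
solver = DPM_Solver(model_fn, ns, algorithm_type="dpmsolver++")
x_T = torch.randn(1, 16)
x_0 = solver.sample(x_T, steps=10, order=2, skip_type="time_uniform", method="multistep")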
- - ===================================================== - - Some advices for choosing the algorithm: - - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: - Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`. - e.g., DPM-Solver: - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - e.g., DPM-Solver++: - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, - skip_type='time_uniform', method='singlestep') - - For **guided sampling with large guidance scale** by DPMs: - Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`. - e.g. - >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, - skip_type='time_uniform', method='multistep') - - We support three types of `skip_type`: - - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** - - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. - - 'time_quadratic': quadratic time for the time steps. - - ===================================================== - Args: - x: A pytorch tensor. The initial value at time `t_start` - e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. - steps: A `int`. The total number of function evaluations (NFE). - t_start: A `float`. The starting time of the sampling. - If `T` is None, we use self.noise_schedule.T (default is 1.0). - t_end: A `float`. The ending time of the sampling. - If `t_end` is None, we use 1. / self.noise_schedule.total_N. - e.g. if total_N == 1000, we have `t_end` == 1e-3. - For discrete-time DPMs: - - We recommend `t_end` == 1. / self.noise_schedule.total_N. - For continuous-time DPMs: - - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. - order: A `int`. The order of DPM-Solver. - skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. - method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. - Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). - - This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and - score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID - for diffusion models sampling by diffusion SDEs for low-resolutional images - (such as CIFAR-10). However, we observed that such trick does not matter for - high-resolutional images. As it needs an additional NFE, we do not recommend - it for high-resolutional images. - lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. - Only valid for `method=multistep` and `steps < 15`. We empirically find that - this trick is a key to stabilizing the sampling by DPM-Solver with very few steps - (especially for steps <= 10). So we recommend to set it to be `True`. - solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`. - atol: A `float`. 
The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. - return_intermediate: A `bool`. Whether to save the xt at each step. - When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. - Returns: - x_end: A pytorch tensor. The approximated solution at time `t_end`. - - """ - t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end - t_T = self.noise_schedule.T if t_start is None else t_start - assert ( - t_0 > 0 and t_T > 0 - ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" - if return_intermediate: - assert method in [ - "multistep", - "singlestep", - "singlestep_fixed", - ], "Cannot use adaptive solver when saving intermediate values" - if self.correcting_xt_fn is not None: - assert method in [ - "multistep", - "singlestep", - "singlestep_fixed", - ], "Cannot use adaptive solver when correcting_xt_fn is not None" - device = x.device - intermediates = [] - with torch.no_grad(): - if method == "adaptive": - x = self.dpm_solver_adaptive( - x, - order=order, - t_T=t_T, - t_0=t_0, - atol=atol, - rtol=rtol, - solver_type=solver_type, - ) - elif method == "multistep": - assert steps >= order - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) - assert timesteps.shape[0] - 1 == steps - # Init the initial values. - step = 0 - t = timesteps[step] - t_prev_list = [t] - model_prev_list = [self.model_fn(x, t)] - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - # Init the first `order` values by lower order multistep DPM-Solver. - for step in range(1, order): - t = timesteps[step] - x = self.multistep_dpm_solver_update( - x, - model_prev_list, - t_prev_list, - t, - step, - solver_type=solver_type, - ) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - t_prev_list.append(t) - model_prev_list.append(self.model_fn(x, t)) - # Compute the remaining values by `order`-th order multistep DPM-Solver. - for step in range(order, steps + 1): - t = timesteps[step] - # We only use lower order for steps < 10 - if lower_order_final and steps < 10: - step_order = min(order, steps + 1 - step) - else: - step_order = order - x = self.multistep_dpm_solver_update( - x, - model_prev_list, - t_prev_list, - t, - step_order, - solver_type=solver_type, - ) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - for i in range(order - 1): - t_prev_list[i] = t_prev_list[i + 1] - model_prev_list[i] = model_prev_list[i + 1] - t_prev_list[-1] = t - # We do not need to evaluate the final model value. 
- if step < steps: - model_prev_list[-1] = self.model_fn(x, t) - elif method in ["singlestep", "singlestep_fixed"]: - if method == "singlestep": - (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver( - steps=steps, - order=order, - skip_type=skip_type, - t_T=t_T, - t_0=t_0, - device=device, - ) - elif method == "singlestep_fixed": - K = steps // order - orders = [ - order, - ] * K - timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) - for step, order in enumerate(orders): - s, t = timesteps_outer[step], timesteps_outer[step + 1] - timesteps_inner = self.get_time_steps( - skip_type=skip_type, - t_T=s.item(), - t_0=t.item(), - N=order, - device=device, - ) - lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) - h = lambda_inner[-1] - lambda_inner[0] - r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h - r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h - x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step) - if return_intermediate: - intermediates.append(x) - else: - raise ValueError("Got wrong method {}".format(method)) - if denoise_to_zero: - t = torch.ones((1,)).to(device) * t_0 - x = self.denoise_to_zero_fn(x, t) - if self.correcting_xt_fn is not None: - x = self.correcting_xt_fn(x, t, step + 1) - if return_intermediate: - intermediates.append(x) - if return_intermediate: - return x, intermediates - else: - return x - - -############################################################# -# other utility functions -############################################################# - - -def interpolate_fn(x, xp, yp): - """ - A piecewise linear function y = f(x), using xp and yp as keypoints. - We implement f(x) in a differentiable way (i.e. applicable for autograd). - The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) - - Args: - x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). - xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. - yp: PyTorch tensor with shape [C, K]. - Returns: - The function values f(x), with shape [N, C]. 
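# A small sketch of interpolate_fn's behavior in the C = 1 case DPM-Solver
# relies on; queries beyond the keypoints extrapolate along the outermost
# segment (illustration only; assumes the interpolate_fn defined here):
import torch

xp = torch.tensor([[0.0, 1.0, 2.0]])     # keypoints, shape [C=1, K=3]
yp = torch.tensor([[0.0, 10.0, 20.0]])   # values at the keypoints
x = torch.tensor([[0.5], [1.5], [3.0]])  # queries, shape [N=3, C=1]
y = interpolate_fn(x, xp, yp)            # -> [[5.0], [15.0], [30.0]]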
- """ - N, K = x.shape[0], xp.shape[1] - all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) - sorted_all_x, x_indices = torch.sort(all_x, dim=2) - x_idx = torch.argmin(x_indices, dim=2) - cand_start_idx = x_idx - 1 - start_idx = torch.where( - torch.eq(x_idx, 0), - torch.tensor(1, device=x.device), - torch.where( - torch.eq(x_idx, K), - torch.tensor(K - 2, device=x.device), - cand_start_idx, - ), - ) - end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) - start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) - end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) - start_idx2 = torch.where( - torch.eq(x_idx, 0), - torch.tensor(0, device=x.device), - torch.where( - torch.eq(x_idx, K), - torch.tensor(K - 2, device=x.device), - cand_start_idx, - ), - ) - y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) - start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) - end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) - cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) - return cand - - -def expand_dims(v, dims): - """ - Expand the tensor `v` to the dim `dims`. - - Args: - `v`: a PyTorch tensor with shape [N]. - `dim`: a `int`. - Returns: - a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. - """ - return v[(...,) + (None,) * (dims - 1)] diff --git a/TTS/tts/layers/xtts/diffusion.py b/TTS/tts/layers/xtts/diffusion.py deleted file mode 100644 index 37665bc676..0000000000 --- a/TTS/tts/layers/xtts/diffusion.py +++ /dev/null @@ -1,1319 +0,0 @@ -import enum -import math - -import numpy as np -import torch -import torch as th -from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral -from tqdm import tqdm - -from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper - -K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} -SAMPLERS = ["dpm++2m", "p", "ddim"] - - -def normal_kl(mean1, logvar1, mean2, logvar2): - """ - Compute the KL divergence between two gaussians. - - Shapes are automatically broadcasted, so batches can be compared to - scalars, among other use cases. - """ - tensor = None - for obj in (mean1, logvar1, mean2, logvar2): - if isinstance(obj, th.Tensor): - tensor = obj - break - assert tensor is not None, "at least one argument must be a Tensor" - - # Force variances to be Tensors. Broadcasting helps convert scalars to - # Tensors, but it does not work for th.exp(). - logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)] - - return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2)) - - -def approx_standard_normal_cdf(x): - """ - A fast approximation of the cumulative distribution function of the - standard normal. - """ - return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) - - -def discretized_gaussian_log_likelihood(x, *, means, log_scales): - """ - Compute the log-likelihood of a Gaussian distribution discretizing to a - given image. - - :param x: the target images. It is assumed that this was uint8 values, - rescaled to the range [-1, 1]. - :param means: the Gaussian mean Tensor. - :param log_scales: the Gaussian log stddev Tensor. - :return: a tensor like x of log probabilities (in nats). 
- """ - assert x.shape == means.shape == log_scales.shape - centered_x = x - means - inv_stdv = th.exp(-log_scales) - plus_in = inv_stdv * (centered_x + 1.0 / 255.0) - cdf_plus = approx_standard_normal_cdf(plus_in) - min_in = inv_stdv * (centered_x - 1.0 / 255.0) - cdf_min = approx_standard_normal_cdf(min_in) - log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) - log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) - cdf_delta = cdf_plus - cdf_min - log_probs = th.where( - x < -0.999, - log_cdf_plus, - th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), - ) - assert log_probs.shape == x.shape - return log_probs - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): - """ - Get a pre-defined beta schedule for the given name. - - The beta schedule library consists of beta schedules which remain similar - in the limit of num_diffusion_timesteps. - Beta schedules may be added, but should not be removed or changed once - they are committed to maintain backwards compatibility. - """ - if schedule_name == "linear": - # Linear schedule from Ho et al, extended to work for any number of - # diffusion steps. - scale = 1000 / num_diffusion_timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) - elif schedule_name == "cosine": - return betas_for_alpha_bar( - num_diffusion_timesteps, - lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, - ) - else: - raise NotImplementedError(f"unknown beta schedule: {schedule_name}") - - -def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, - which defines the cumulative product of (1-beta) over time from t = [0,1]. - - :param num_diffusion_timesteps: the number of betas to produce. - :param alpha_bar: a lambda that takes an argument t from 0 to 1 and - produces the cumulative product of (1-beta) up to that - part of the diffusion process. - :param max_beta: the maximum beta to use; use values lower than 1 to - prevent singularities. - """ - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return np.array(betas) - - -class ModelMeanType(enum.Enum): - """ - Which type of output the model predicts. - """ - - PREVIOUS_X = "previous_x" # the model predicts x_{t-1} - START_X = "start_x" # the model predicts x_0 - EPSILON = "epsilon" # the model predicts epsilon - - -class ModelVarType(enum.Enum): - """ - What is used as the model's output variance. - - The LEARNED_RANGE option has been added to allow the model to predict - values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
- """ - - LEARNED = "learned" - FIXED_SMALL = "fixed_small" - FIXED_LARGE = "fixed_large" - LEARNED_RANGE = "learned_range" - - -class LossType(enum.Enum): - MSE = "mse" # use raw MSE loss (and KL when learning variances) - RESCALED_MSE = "rescaled_mse" # use raw MSE loss (with RESCALED_KL when learning variances) - KL = "kl" # use the variational lower-bound - RESCALED_KL = "rescaled_kl" # like KL, but rescale to estimate the full VLB - - def is_vb(self): - return self == LossType.KL or self == LossType.RESCALED_KL - - -class GaussianDiffusion: - """ - Utilities for training and sampling diffusion models. - - Ported directly from here, and then adapted over time to further experimentation. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 - - :param betas: a 1-D numpy array of betas for each diffusion timestep, - starting at T and going to 1. - :param model_mean_type: a ModelMeanType determining what the model outputs. - :param model_var_type: a ModelVarType determining how variance is output. - :param loss_type: a LossType determining the loss function to use. - :param rescale_timesteps: if True, pass floating point timesteps into the - model so that they are always scaled like in the - original paper (0 to 1000). - """ - - def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type, - rescale_timesteps=False, # this is generally False - conditioning_free=False, - conditioning_free_k=1, - ramp_conditioning_free=True, - sampler="ddim", - ): - self.sampler = sampler - self.model_mean_type = ModelMeanType(model_mean_type) - self.model_var_type = ModelVarType(model_var_type) - self.loss_type = LossType(loss_type) - self.rescale_timesteps = rescale_timesteps - self.conditioning_free = conditioning_free - self.conditioning_free_k = conditioning_free_k - self.ramp_conditioning_free = ramp_conditioning_free - - # Use float64 for accuracy. - betas = np.array(betas, dtype=np.float64) - self.betas = betas - assert len(betas.shape) == 1, "betas must be 1-D" - assert (betas > 0).all() and (betas <= 1).all() - - self.num_timesteps = int(betas.shape[0]) - - alphas = 1.0 - betas - self.alphas_cumprod = np.cumprod(alphas, axis=0) - self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) - self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) - assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) - - # calculations for diffusion q(x_t | x_{t-1}) and others - self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) - self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) - self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) - self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) - self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - # log calculation clipped because the posterior variance is 0 at the - # beginning of the diffusion chain. - self.posterior_log_variance_clipped = np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:])) - self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) - self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) - - def q_mean_variance(self, x_start, t): - """ - Get the distribution q(x_t | x_0). 
- - :param x_start: the [N x C x ...] tensor of noiseless inputs. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :return: A tuple (mean, variance, log_variance), all of x_start's shape. - """ - mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) - log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def q_sample(self, x_start, t, noise=None): - """ - Diffuse the data for a given number of diffusion steps. - - In other words, sample from q(x_t | x_0). - - :param x_start: the initial data batch. - :param t: the number of diffusion steps (minus 1). Here, 0 means one step. - :param noise: if specified, the split-out normal noise. - :return: A noisy version of x_start. - """ - if noise is None: - noise = th.randn_like(x_start) - assert noise.shape == x_start.shape - return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - def q_posterior_mean_variance(self, x_start, x_t, t): - """ - Compute the mean and variance of the diffusion posterior: - - q(x_{t-1} | x_t, x_0) - - """ - assert x_start.shape == x_t.shape - posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) - assert ( - posterior_mean.shape[0] - == posterior_variance.shape[0] - == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] - ) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): - """ - Apply the model to get p(x_{t-1} | x_t), as well as a prediction of - the initial x, x_0. - - :param model: the model, which takes a signal and a batch of timesteps - as input. - :param x: the [N x C x ...] tensor at time t. - :param t: a 1-D Tensor of timesteps. - :param clip_denoised: if True, clip the denoised signal into [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. Applies before - clip_denoised. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict with the following keys: - - 'mean': the model mean output. - - 'variance': the model variance output. - - 'log_variance': the log of 'variance'. - - 'pred_xstart': the prediction for x_0. 
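The closed form used by `q_sample` above is the standard reparameterization of the forward marginal, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A minimal sketch with a per-batch scalar `alpha_bar_t` instead of the `_extract_into_tensor` gather used above; shapes are hypothetical:

```python
import torch

def q_sample_scalar(x_start, alpha_bar_t, noise=None):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    if noise is None:
        noise = torch.randn_like(x_start)
    a = torch.as_tensor(alpha_bar_t, dtype=x_start.dtype)
    return a.sqrt() * x_start + (1.0 - a).sqrt() * noise

x0 = torch.randn(2, 80, 100)   # hypothetical mel-shaped batch
xt = q_sample_scalar(x0, 0.5)  # halfway-noised sample
print(xt.shape)                # torch.Size([2, 80, 100])
```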
- """ - if model_kwargs is None: - model_kwargs = {} - - assert self.model_var_type == ModelVarType.LEARNED_RANGE - assert self.model_mean_type == ModelMeanType.EPSILON - assert denoised_fn is None - assert clip_denoised is True - B, C = x.shape[:2] - assert t.shape == (B,) - model_output = model(x, self._scale_timesteps(t), **model_kwargs) - if self.conditioning_free: - model_output_no_conditioning = model(x, self._scale_timesteps(t), conditioning_free=True, **model_kwargs) - - if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: - assert model_output.shape == (B, C * 2, *x.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) - if self.conditioning_free: - model_output_no_conditioning, _ = th.split(model_output_no_conditioning, C, dim=1) - if self.model_var_type == ModelVarType.LEARNED: - assert False - model_log_variance = model_var_values - model_variance = th.exp(model_log_variance) - else: - min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) - max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) - # The model_var_values is [-1, 1] for [min_var, max_var]. - frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log - model_variance = th.exp(model_log_variance) - else: - assert False - model_variance, model_log_variance = { - # for fixedlarge, we set the initial (log-)variance like so - # to get a better decoder log likelihood. - ModelVarType.FIXED_LARGE: ( - np.append(self.posterior_variance[1], self.betas[1:]), - np.log(np.append(self.posterior_variance[1], self.betas[1:])), - ), - ModelVarType.FIXED_SMALL: ( - self.posterior_variance, - self.posterior_log_variance_clipped, - ), - }[self.model_var_type] - model_variance = _extract_into_tensor(model_variance, t, x.shape) - model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) - - if self.conditioning_free: - if self.ramp_conditioning_free: - assert t.shape[0] == 1 # This should only be used in inference. 
- cfk = self.conditioning_free_k * (1 - self._scale_timesteps(t)[0].item() / self.num_timesteps) - else: - cfk = self.conditioning_free_k - model_output = (1 + cfk) * model_output - cfk * model_output_no_conditioning - - def process_xstart(x): - if denoised_fn is not None: - assert False - x = denoised_fn(x) - if clip_denoised: - return x.clamp(-1, 1) - assert False - return x - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - assert False - pred_xstart = process_xstart(self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)) - model_mean = model_output - elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: - if self.model_mean_type == ModelMeanType.START_X: - assert False - pred_xstart = process_xstart(model_output) - else: - pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)) - model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) - else: - raise NotImplementedError(self.model_mean_type) - - assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape - return { - "mean": model_mean, - "variance": model_variance, - "log_variance": model_log_variance, - "pred_xstart": pred_xstart, - } - - def _predict_xstart_from_eps(self, x_t, t, eps): - assert x_t.shape == eps.shape - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps - ) - - def _predict_xstart_from_xprev(self, x_t, t, xprev): - assert x_t.shape == xprev.shape - return ( # (xprev - coef2*x_t) / coef1 - _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t - ) - - def _predict_eps_from_xstart(self, x_t, t, pred_xstart): - return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) - - def _scale_timesteps(self, t): - if self.rescale_timesteps: - return t.float() * (1000.0 / self.num_timesteps) - return t - - def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute the mean for the previous step, given a function cond_fn that - computes the gradient of a conditional log probability with respect to - x. In particular, cond_fn computes grad(log(p(y|x))), and we want to - condition on y. - - This uses the conditioning strategy from Sohl-Dickstein et al. (2015). - """ - gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs) - new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() - return new_mean - - def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): - """ - Compute what the p_mean_variance output would have been, should the - model's score function be conditioned by cond_fn. - - See condition_mean() for details on cond_fn. - - Unlike condition_mean(), this instead uses the conditioning strategy - from Song et al (2020). 
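The combination applied above is the usual classifier-free guidance extrapolation; k = 0 recovers the conditional prediction, and larger k pushes the output further from the unconditional one. A minimal sketch:

```python
# Conditioning-free ("classifier-free") guidance as combined above:
#   out = (1 + k) * out_cond - k * out_uncond
def apply_guidance(out_cond, out_uncond, k):
    return (1.0 + k) * out_cond - k * out_uncond

print(apply_guidance(1.0, 0.2, k=2.0))  # 2.6
```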
- """ - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - - eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) - eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, self._scale_timesteps(t), **model_kwargs) - - out = p_mean_var.copy() - out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) - out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) - return out - - def p_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - ): - """ - Sample x_{t-1} from the model at the given timestep. - - :param model: the model to sample from. - :param x: the current tensor at x_{t-1}. - :param t: the value of t, starting at 0 for the first diffusion step. - :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :return: a dict containing the following keys: - - 'sample': a random sample from the model. - - 'pred_xstart': a prediction of x_0. - """ - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - noise = th.randn_like(x) - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - if cond_fn is not None: - out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) - sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def k_diffusion_sample_loop( - self, - k_sampler, - pbar, - model, - shape, - noise=None, # all given - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - device=None, # ALL UNUSED - model_kwargs=None, # {'precomputed_aligned_embeddings': precomputed_embeddings}, - progress=False, # unused as well - ): - assert isinstance(model_kwargs, dict) - if device is None: - device = next(model.parameters()).device - s_in = noise.new_ones([noise.shape[0]]) - - def model_split(*args, **kwargs): - model_output = model(*args, **kwargs) - model_epsilon, model_var = th.split(model_output, model_output.shape[1] // 2, dim=1) - return model_epsilon, model_var - - # - """ - print(self.betas) - print(th.tensor(self.betas)) - noise_schedule = NoiseScheduleVP(schedule='discrete', betas=th.tensor(self.betas)) - """ - noise_schedule = NoiseScheduleVP(schedule="linear", continuous_beta_0=0.1 / 4, continuous_beta_1=20.0 / 4) - - def model_fn_prewrap(x, t, *args, **kwargs): - """ - x_in = torch.cat([x] * 2) - t_in = torch.cat([t_continuous] * 2) - c_in = torch.cat([unconditional_condition, condition]) - noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) - print(t) - print(self.timestep_map) - exit() - """ - """ - model_output = model(x, self._scale_timesteps(t*4000), **model_kwargs) - out = self.p_mean_variance(model, x, t*4000, model_kwargs=model_kwargs) - return out['pred_xstart'] - """ - x, _ = x.chunk(2) - t, _ = (t * 1000).chunk(2) - res = torch.cat( - [ - model_split(x, t, conditioning_free=True, **model_kwargs)[0], - model_split(x, t, **model_kwargs)[0], - ] - ) - pbar.update(1) - return res - - model_fn = model_wrapper( - model_fn_prewrap, - noise_schedule, - 
model_type="noise", # "noise" or "x_start" or "v" or "score" - model_kwargs=model_kwargs, - guidance_type="classifier-free", - condition=th.Tensor(1), - unconditional_condition=th.Tensor(1), - guidance_scale=self.conditioning_free_k, - ) - """ - model_fn = model_wrapper( - model_fn_prewrap, - noise_schedule, - model_type='x_start', - model_kwargs={} - ) - # - dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver") - x_sample = dpm_solver.sample( - noise, - steps=20, - order=3, - skip_type="time_uniform", - method="singlestep", - ) - """ - dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") - x_sample = dpm_solver.sample( - noise, - steps=self.num_timesteps, - order=2, - skip_type="time_uniform", - method="multistep", - ) - #''' - return x_sample - - # HF DIFFUSION ATTEMPT - """ - from .hf_diffusion import EulerAncestralDiscreteScheduler - Scheduler = EulerAncestralDiscreteScheduler() - Scheduler.set_timesteps(100) - for timestep in Scheduler.timesteps: - noise_input = Scheduler.scale_model_input(noise, timestep) - ts = s_in * timestep - model_output = model(noise_input, ts, **model_kwargs) - model_epsilon, _model_var = th.split(model_output, model_output.shape[1]//2, dim=1) - noise, _x0 = Scheduler.step(model_epsilon, timestep, noise) - return noise - """ - - # KARRAS DIFFUSION ATTEMPT - """ - TRAINED_DIFFUSION_STEPS = 4000 # HARDCODED - ratio = TRAINED_DIFFUSION_STEPS/14.5 - def call_model(*args, **kwargs): - model_output = model(*args, **kwargs) - model_output, model_var_values = th.split(model_output, model_output.shape[1]//2, dim=1) - return model_output - print(get_sigmas_karras(self.num_timesteps, sigma_min=0.0, sigma_max=4000, device=device)) - exit() - sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) - return k_sampler(call_model, noise, sigmas, extra_args=model_kwargs, disable=not progress) - ''' - sigmas = get_sigmas_karras(self.num_timesteps, sigma_min=0.03, sigma_max=14.5, device=device) - step = 0 # LMAO - global_sigmas = None - # - def fakemodel(x, t, **model_kwargs): - print(t,global_sigmas*ratio) - return model(x, t, **model_kwargs) - def denoised(x, sigmas, **extra_args): - t = th.tensor([self.num_timesteps-step-1] * shape[0], device=device) - nonlocal global_sigmas - global_sigmas = sigmas - with th.no_grad(): - out = self.p_sample( - fakemodel, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - ) - return out["sample"] - def callback(d): - nonlocal step - step += 1 - - return k_sampler(denoised, noise, sigmas, extra_args=model_kwargs, callback=callback, disable=not progress) - ''' - """ - - def sample_loop(self, *args, **kwargs): - s = self.sampler - if s == "p": - return self.p_sample_loop(*args, **kwargs) - elif s == "ddim": - return self.ddim_sample_loop(*args, **kwargs) - elif s == "dpm++2m": - if self.conditioning_free is not True: - raise RuntimeError("cond_free must be true") - with tqdm(total=self.num_timesteps) as pbar: - return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) - else: - raise RuntimeError("sampler not impl") - - def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - ): - """ - Generate samples from the model. - - :param model: the model module. - :param shape: the shape of the samples, (N, C, H, W). 
- :param noise: if specified, the noise from the encoder to sample. - Should be of the same shape as `shape`. - :param clip_denoised: if True, clip x_start predictions to [-1, 1]. - :param denoised_fn: if not None, a function which applies to the - x_start prediction before it is used to sample. - :param cond_fn: if not None, this is a gradient function that acts - similarly to the model. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param device: if specified, the device to create the samples on. - If not specified, use a model parameter's device. - :param progress: if True, show a tqdm progress bar. - :return: a non-differentiable batch of samples. - """ - final = None - for sample in self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, - ): - final = sample - return final["sample"] - - def p_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - ): - """ - Generate samples from the model and yield intermediate samples from - each timestep of diffusion. - - Arguments are the same as p_sample_loop(). - Returns a generator over dicts, where each dict is the return value of - p_sample(). - """ - if device is None: - device = next(model.parameters()).device - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] - - for i in tqdm(indices, disable=not progress): - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): - out = self.p_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - ) - yield out - img = out["sample"] - - def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t-1} from the model using DDIM. - - Same usage as p_sample(). - """ - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - if cond_fn is not None: - out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) - - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) - - alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) - alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) - sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev) - # Equation 12. - noise = th.randn_like(x) - mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps - nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0 - sample = mean_pred + nonzero_mask * sigma * noise - return {"sample": sample, "pred_xstart": out["pred_xstart"]} - - def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - denoised_fn=None, - model_kwargs=None, - eta=0.0, - ): - """ - Sample x_{t+1} from the model using DDIM reverse ODE. 
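"Equation 12" as used in `ddim_sample` above simplifies with eta = 0 (sigma vanishes) to the deterministic DDIM update. A minimal sketch, assuming `alpha_bar_prev` is already broadcastable to the sample shape:

```python
import torch

# Deterministic DDIM step (eta = 0):
#   x_{t-1} = sqrt(abar_prev) * x0_pred + sqrt(1 - abar_prev) * eps
def ddim_step_deterministic(pred_xstart, eps, alpha_bar_prev):
    return pred_xstart * alpha_bar_prev.sqrt() + (1.0 - alpha_bar_prev).sqrt() * eps
```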
- """ - assert eta == 0.0, "Reverse ODE only for deterministic path" - out = self.p_mean_variance( - model, - x, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - model_kwargs=model_kwargs, - ) - # Usually our model outputs epsilon, but we re-derive it - # in case we used x_start or x_prev prediction. - eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) - alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) - - # Equation 12. reversed - mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps - - return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} - - def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - ): - """ - Generate samples from the model using DDIM. - - Same usage as p_sample_loop(). - """ - final = None - for sample in self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - device=device, - progress=progress, - eta=eta, - ): - final = sample - return final["sample"] - - def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - denoised_fn=None, - cond_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - ): - """ - Use DDIM to sample from the model and yield intermediate samples from - each timestep of DDIM. - - Same usage as p_sample_loop_progressive(). - """ - if device is None: - device = next(model.parameters()).device - assert isinstance(shape, (tuple, list)) - if noise is not None: - img = noise - else: - img = th.randn(*shape, device=device) - indices = list(range(self.num_timesteps))[::-1] - - if progress: - # Lazy import so that we don't depend on tqdm. - from tqdm.auto import tqdm - - indices = tqdm(indices, disable=not progress) - - for i in indices: - t = th.tensor([i] * shape[0], device=device) - with th.no_grad(): - out = self.ddim_sample( - model, - img, - t, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - cond_fn=cond_fn, - model_kwargs=model_kwargs, - eta=eta, - ) - yield out - img = out["sample"] - - def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None): - """ - Get a term for the variational lower-bound. - - The resulting units are bits (rather than nats, as one might expect). - This allows for comparison to other papers. - - :return: a dict with the following keys: - - 'output': a shape [N] tensor of NLLs or KLs. - - 'pred_xstart': the x_0 predictions. 
- """ - true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t) - out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs) - kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]) - kl = mean_flat(kl) / np.log(2.0) - - decoder_nll = -discretized_gaussian_log_likelihood( - x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] - ) - assert decoder_nll.shape == x_start.shape - decoder_nll = mean_flat(decoder_nll) / np.log(2.0) - - # At the first timestep return the decoder NLL, - # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) - output = th.where((t == 0), decoder_nll, kl) - return {"output": output, "pred_xstart": out["pred_xstart"]} - - def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): - """ - Compute training losses for a single timestep. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param t: a batch of timestep indices. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param noise: if specified, the specific Gaussian noise to try to remove. - :return: a dict with the key "loss" containing a tensor of shape [N]. - Some mean or variance settings may also have other keys. - """ - if model_kwargs is None: - model_kwargs = {} - if noise is None: - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) - - terms = {} - - if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: - # TODO: support multiple model outputs for this mode. - terms["loss"] = self._vb_terms_bpd( - model=model, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - model_kwargs=model_kwargs, - )["output"] - if self.loss_type == LossType.RESCALED_KL: - terms["loss"] *= self.num_timesteps - elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_outputs = model(x_t, self._scale_timesteps(t), **model_kwargs) - if isinstance(model_outputs, tuple): - model_output = model_outputs[0] - terms["extra_outputs"] = model_outputs[1:] - else: - model_output = model_outputs - - if self.model_var_type in [ - ModelVarType.LEARNED, - ModelVarType.LEARNED_RANGE, - ]: - B, C = x_t.shape[:2] - assert model_output.shape == (B, C * 2, *x_t.shape[2:]) - model_output, model_var_values = th.split(model_output, C, dim=1) - # Learn the variance using the variational bound, but don't let - # it affect our mean prediction. - frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) - terms["vb"] = self._vb_terms_bpd( - model=lambda *args, r=frozen_out: r, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - )["output"] - if self.loss_type == LossType.RESCALED_MSE: - # Divide by 1000 for equivalence with initial implementation. - # Without a factor of 1/1000, the VB term hurts the MSE term. - terms["vb"] *= self.num_timesteps / 1000.0 - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] - x_start_pred = torch.zeros(x_start) # Not supported. 
- elif self.model_mean_type == ModelMeanType.START_X: - target = x_start - x_start_pred = model_output - elif self.model_mean_type == ModelMeanType.EPSILON: - target = noise - x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) - else: - raise NotImplementedError(self.model_mean_type) - assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) - terms["x_start_predicted"] = x_start_pred - if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] - else: - terms["loss"] = terms["mse"] - else: - raise NotImplementedError(self.loss_type) - - return terms - - def autoregressive_training_losses( - self, - model, - x_start, - t, - model_output_keys, - gd_out_key, - model_kwargs=None, - noise=None, - ): - """ - Compute training losses for a single timestep. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param t: a batch of timestep indices. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - :param noise: if specified, the specific Gaussian noise to try to remove. - :return: a dict with the key "loss" containing a tensor of shape [N]. - Some mean or variance settings may also have other keys. - """ - if model_kwargs is None: - model_kwargs = {} - if noise is None: - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start, t, noise=noise) - terms = {} - if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: - assert False # not currently supported for this type of diffusion. - elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: - model_outputs = model(x_t, x_start, self._scale_timesteps(t), **model_kwargs) - terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) - model_output = terms[gd_out_key] - if self.model_var_type in [ - ModelVarType.LEARNED, - ModelVarType.LEARNED_RANGE, - ]: - B, C = x_t.shape[:2] - assert model_output.shape == (B, C, 2, *x_t.shape[2:]) - model_output, model_var_values = ( - model_output[:, :, 0], - model_output[:, :, 1], - ) - # Learn the variance using the variational bound, but don't let - # it affect our mean prediction. - frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) - terms["vb"] = self._vb_terms_bpd( - model=lambda *args, r=frozen_out: r, - x_start=x_start, - x_t=x_t, - t=t, - clip_denoised=False, - )["output"] - if self.loss_type == LossType.RESCALED_MSE: - # Divide by 1000 for equivalence with initial implementation. - # Without a factor of 1/1000, the VB term hurts the MSE term. - terms["vb"] *= self.num_timesteps / 1000.0 - - if self.model_mean_type == ModelMeanType.PREVIOUS_X: - target = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0] - x_start_pred = torch.zeros(x_start) # Not supported. 
- elif self.model_mean_type == ModelMeanType.START_X: - target = x_start - x_start_pred = model_output - elif self.model_mean_type == ModelMeanType.EPSILON: - target = noise - x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) - else: - raise NotImplementedError(self.model_mean_type) - assert model_output.shape == target.shape == x_start.shape - terms["mse"] = mean_flat((target - model_output) ** 2) - terms["x_start_predicted"] = x_start_pred - if "vb" in terms: - terms["loss"] = terms["mse"] + terms["vb"] - else: - terms["loss"] = terms["mse"] - else: - raise NotImplementedError(self.loss_type) - - return terms - - def _prior_bpd(self, x_start): - """ - Get the prior KL term for the variational lower-bound, measured in - bits-per-dim. - - This term can't be optimized, as it only depends on the encoder. - - :param x_start: the [N x C x ...] tensor of inputs. - :return: a batch of [N] KL values (in bits), one per batch element. - """ - batch_size = x_start.shape[0] - t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) - qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) - kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) - return mean_flat(kl_prior) / np.log(2.0) - - def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): - """ - Compute the entire variational lower-bound, measured in bits-per-dim, - as well as other related quantities. - - :param model: the model to evaluate loss on. - :param x_start: the [N x C x ...] tensor of inputs. - :param clip_denoised: if True, clip denoised samples. - :param model_kwargs: if not None, a dict of extra keyword arguments to - pass to the model. This can be used for conditioning. - - :return: a dict containing the following keys: - - total_bpd: the total variational lower-bound, per batch element. - - prior_bpd: the prior term in the lower-bound. - - vb: an [N x T] tensor of terms in the lower-bound. - - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. - - mse: an [N x T] tensor of epsilon MSEs for each timestep. - """ - device = x_start.device - batch_size = x_start.shape[0] - - vb = [] - xstart_mse = [] - mse = [] - for t in list(range(self.num_timesteps))[::-1]: - t_batch = th.tensor([t] * batch_size, device=device) - noise = th.randn_like(x_start) - x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) - # Calculate VLB term at the current timestep - with th.no_grad(): - out = self._vb_terms_bpd( - model, - x_start=x_start, - x_t=x_t, - t=t_batch, - clip_denoised=clip_denoised, - model_kwargs=model_kwargs, - ) - vb.append(out["output"]) - xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) - eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) - mse.append(mean_flat((eps - noise) ** 2)) - - vb = th.stack(vb, dim=1) - xstart_mse = th.stack(xstart_mse, dim=1) - mse = th.stack(mse, dim=1) - - prior_bpd = self._prior_bpd(x_start) - total_bpd = vb.sum(dim=1) + prior_bpd - return { - "total_bpd": total_bpd, - "prior_bpd": prior_bpd, - "vb": vb, - "xstart_mse": xstart_mse, - "mse": mse, - } - - -class SpacedDiffusion(GaussianDiffusion): - """ - A diffusion process which can skip steps in a base diffusion process. - - :param use_timesteps: a collection (sequence or set) of timesteps from the - original diffusion process to retain. - :param kwargs: the kwargs to create the base diffusion process. 
- """ - - def __init__(self, use_timesteps, **kwargs): - self.use_timesteps = set(use_timesteps) - self.timestep_map = [] - self.original_num_steps = len(kwargs["betas"]) - - base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa - last_alpha_cumprod = 1.0 - new_betas = [] - for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): - if i in self.use_timesteps: - new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) - last_alpha_cumprod = alpha_cumprod - self.timestep_map.append(i) - kwargs["betas"] = np.array(new_betas) - super().__init__(**kwargs) - - def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) - - def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().training_losses(self._wrap_model(model), *args, **kwargs) - - def autoregressive_training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs - return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) - - def condition_mean(self, cond_fn, *args, **kwargs): - return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) - - def condition_score(self, cond_fn, *args, **kwargs): - return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) - - def _wrap_model(self, model, autoregressive=False): - if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): - return model - mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel - return mod(model, self.timestep_map, self.rescale_timesteps, self.original_num_steps) - - def _scale_timesteps(self, t): - # Scaling is done by the wrapped model. - return t - - -def space_timesteps(num_timesteps, section_counts): - """ - Create a list of timesteps to use from an original diffusion process, - given the number of timesteps we want to take from equally-sized portions - of the original process. - - For example, if there's 300 timesteps and the section counts are [10,15,20] - then the first 100 timesteps are strided to be 10 timesteps, the second 100 - are strided to be 15 timesteps, and the final 100 are strided to be 20. - - If the stride is a string starting with "ddim", then the fixed striding - from the DDIM paper is used, and only one section is allowed. - - :param num_timesteps: the number of diffusion steps in the original - process to divide up. - :param section_counts: either a list of numbers, or a string containing - comma-separated numbers, indicating the step count - per section. As a special case, use "ddimN" where N - is a number of steps to use the striding from the - DDIM paper. - :return: a set of diffusion steps from the original process to use. 
- """ - if isinstance(section_counts, str): - if section_counts.startswith("ddim"): - desired_count = int(section_counts[len("ddim") :]) - for i in range(1, num_timesteps): - if len(range(0, num_timesteps, i)) == desired_count: - return set(range(0, num_timesteps, i)) - raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") - section_counts = [int(x) for x in section_counts.split(",")] - size_per = num_timesteps // len(section_counts) - extra = num_timesteps % len(section_counts) - start_idx = 0 - all_steps = [] - for i, section_count in enumerate(section_counts): - size = size_per + (1 if i < extra else 0) - if size < section_count: - raise ValueError(f"cannot divide section of {size} steps into {section_count}") - if section_count <= 1: - frac_stride = 1 - else: - frac_stride = (size - 1) / (section_count - 1) - cur_idx = 0.0 - taken_steps = [] - for _ in range(section_count): - taken_steps.append(start_idx + round(cur_idx)) - cur_idx += frac_stride - all_steps += taken_steps - start_idx += size - return set(all_steps) - - -class _WrappedModel: - def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): - self.model = model - self.timestep_map = timestep_map - self.rescale_timesteps = rescale_timesteps - self.original_num_steps = original_num_steps - - def __call__(self, x, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) - new_ts = map_tensor[ts] - if self.rescale_timesteps: - new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, new_ts, **kwargs) - - -class _WrappedAutoregressiveModel: - def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): - self.model = model - self.timestep_map = timestep_map - self.rescale_timesteps = rescale_timesteps - self.original_num_steps = original_num_steps - - def __call__(self, x, x0, ts, **kwargs): - map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) - new_ts = map_tensor[ts] - if self.rescale_timesteps: - new_ts = new_ts.float() * (1000.0 / self.original_num_steps) - return self.model(x, x0, new_ts, **kwargs) - - -def _extract_into_tensor(arr, timesteps, broadcast_shape): - """ - Extract values from a 1-D numpy array for a batch of indices. - - :param arr: the 1-D numpy array. - :param timesteps: a tensor of indices into the array to extract. - :param broadcast_shape: a larger shape of K dimensions with the batch - dimension equal to the length of timesteps. - :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. - """ - res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() - while len(res.shape) < len(broadcast_shape): - res = res[..., None] - return res.expand(broadcast_shape) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 477f31bfc9..6b8a73e859 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -9,8 +9,6 @@ from coqpit import Coqpit from TTS.tts.layers.tortoise.audio_utils import denormalize_tacotron_mel, wav_to_univnet_mel -from TTS.tts.layers.tortoise.diffusion_decoder import DiffusionTts -from TTS.tts.layers.xtts.diffusion import SpacedDiffusion, get_named_beta_schedule, space_timesteps from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support @@ -168,12 +166,10 @@ class XttsAudioConfig(Coqpit): Args: sample_rate (int): The sample rate in which the GPT operates. 
- diffusion_sample_rate (int): The sample rate of the diffusion audio waveform. output_sample_rate (int): The sample rate of the output audio waveform. """ sample_rate: int = 22050 - diffusion_sample_rate: int = 24000 output_sample_rate: int = 24000 @@ -697,24 +693,11 @@ def inference( hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) - else: - assert hasattr( - self, "diffusion_decoder" - ), "You must disable hifigan decoders to use difffusion by setting `use_hifigan: false`" - mel = do_spectrogram_diffusion( - self.diffusion_decoder, - diffuser, - gpt_latents, - diffusion_conditioning, - temperature=diffusion_temperature, - ) - wav = self.vocoder.inference(mel) return { "wav": wav.cpu().numpy().squeeze(), "gpt_latents": gpt_latents, "speaker_embedding": speaker_embedding, - "diffusion_conditioning": diffusion_conditioning, } def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 9134be0db2..65d3ccd04d 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -98,7 +98,7 @@ def main(): ) # define audio config audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) # training parameters config config = GPTTrainerConfig( diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index ee6b22becd..3bb68e2f3f 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -99,7 +99,7 @@ def main(): ) # define audio config audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) # training parameters config config = GPTTrainerConfig( diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 47b1dd7d27..09df98eff6 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -89,7 +89,7 @@ use_ne_hifigan=True, ) audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) config = GPTTrainerConfig( epochs=1, diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 6b6f1330dc..0851a4e2c6 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -89,7 +89,7 @@ use_ne_hifigan=True, ) audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, diffusion_sample_rate=24000, output_sample_rate=24000 + sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 ) config = GPTTrainerConfig( epochs=1, From b702b39b52753c731881ce01d3edbf7e88f06ed3 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 6 Nov 2023 19:02:09 +0100 Subject: [PATCH 2/4] Make style --- TTS/tts/models/base_tacotron.py | 7 ++++++- TTS/tts/models/tortoise.py | 7 ++++++- recipes/ljspeech/xtts_v1/train_gpt_xtts.py | 4 +--- recipes/ljspeech/xtts_v2/train_gpt_xtts.py | 4 +--- 
tests/xtts_tests/test_xtts_gpt_train.py | 4 +--- tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 4 +--- 6 files changed, 16 insertions(+), 14 deletions(-) diff --git a/TTS/tts/models/base_tacotron.py b/TTS/tts/models/base_tacotron.py index 4aaf526111..f38dace235 100644 --- a/TTS/tts/models/base_tacotron.py +++ b/TTS/tts/models/base_tacotron.py @@ -252,7 +252,12 @@ def compute_gst(self, inputs, style_input, speaker_embedding=None): def compute_capacitron_VAE_embedding(self, inputs, reference_mel_info, text_info=None, speaker_embedding=None): """Capacitron Variational Autoencoder""" - (VAE_outputs, posterior_distribution, prior_distribution, capacitron_beta,) = self.capacitron_vae_layer( + ( + VAE_outputs, + posterior_distribution, + prior_distribution, + capacitron_beta, + ) = self.capacitron_vae_layer( reference_mel_info, text_info, speaker_embedding, # pylint: disable=not-callable diff --git a/TTS/tts/models/tortoise.py b/TTS/tts/models/tortoise.py index c8cfcfdd04..16644ff95e 100644 --- a/TTS/tts/models/tortoise.py +++ b/TTS/tts/models/tortoise.py @@ -676,7 +676,12 @@ def inference( ), "Too much text provided. Break the text up into separate segments and re-try inference." if voice_samples is not None: - (auto_conditioning, diffusion_conditioning, _, _,) = self.get_conditioning_latents( + ( + auto_conditioning, + diffusion_conditioning, + _, + _, + ) = self.get_conditioning_latents( voice_samples, return_mels=True, latent_averaging_mode=latent_averaging_mode, diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 65d3ccd04d..02e35dfd75 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -97,9 +97,7 @@ def main(): use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint ) # define audio config - audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 - ) + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) # training parameters config config = GPTTrainerConfig( output_path=OUT_PATH, diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index 3bb68e2f3f..4d06fed168 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -98,9 +98,7 @@ def main(): gpt_use_perceiver_resampler=True, ) # define audio config - audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 - ) + audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) # training parameters config config = GPTTrainerConfig( output_path=OUT_PATH, diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 09df98eff6..83cf537fb2 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -88,9 +88,7 @@ gpt_stop_audio_token=8193, use_ne_hifigan=True, ) -audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 -) +audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) config = GPTTrainerConfig( epochs=1, output_path=OUT_PATH, diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index 0851a4e2c6..b9f6438eef 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ 
b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -88,9 +88,7 @@ gpt_use_perceiver_resampler=True, use_ne_hifigan=True, ) -audio_config = XttsAudioConfig( - sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000 -) +audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) config = GPTTrainerConfig( epochs=1, output_path=OUT_PATH, From 97b29a280e14914eebd21a4d9b65948ace10c760 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Mon, 6 Nov 2023 19:19:05 +0100 Subject: [PATCH 3/4] Drop diffusion deps in code --- TTS/tts/layers/xtts/vocoder.py | 385 ------------------- TTS/tts/models/xtts.py | 86 +---- recipes/ljspeech/xtts_v1/train_gpt_xtts.py | 1 - recipes/ljspeech/xtts_v2/train_gpt_xtts.py | 1 - tests/xtts_tests/test_xtts_gpt_train.py | 1 - tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 1 - 6 files changed, 16 insertions(+), 459 deletions(-) delete mode 100644 TTS/tts/layers/xtts/vocoder.py diff --git a/TTS/tts/layers/xtts/vocoder.py b/TTS/tts/layers/xtts/vocoder.py deleted file mode 100644 index 0f4991b886..0000000000 --- a/TTS/tts/layers/xtts/vocoder.py +++ /dev/null @@ -1,385 +0,0 @@ -import json -from dataclasses import dataclass -from enum import Enum -from typing import Callable, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -MAX_WAV_VALUE = 32768.0 - - -class KernelPredictor(torch.nn.Module): - """Kernel predictor for the location-variable convolutions""" - - def __init__( - self, - cond_channels, - conv_in_channels, - conv_out_channels, - conv_layers, - conv_kernel_size=3, - kpnet_hidden_channels=64, - kpnet_conv_size=3, - kpnet_dropout=0.0, - kpnet_nonlinear_activation="LeakyReLU", - kpnet_nonlinear_activation_params={"negative_slope": 0.1}, - ): - """ - Args: - cond_channels (int): number of channel for the conditioning sequence, - conv_in_channels (int): number of channel for the input sequence, - conv_out_channels (int): number of channel for the output sequence, - conv_layers (int): number of layers - """ - super().__init__() - - self.conv_in_channels = conv_in_channels - self.conv_out_channels = conv_out_channels - self.conv_kernel_size = conv_kernel_size - self.conv_layers = conv_layers - - kpnet_kernel_channels = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers # l_w - kpnet_bias_channels = conv_out_channels * conv_layers # l_b - - self.input_conv = nn.Sequential( - nn.utils.weight_norm(nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=2, bias=True)), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - ) - - self.residual_convs = nn.ModuleList() - padding = (kpnet_conv_size - 1) // 2 - for _ in range(3): - self.residual_convs.append( - nn.Sequential( - nn.Dropout(kpnet_dropout), - nn.utils.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_hidden_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - nn.utils.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_hidden_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ), - getattr(nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), - ) - ) - self.kernel_conv = nn.utils.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_kernel_channels, - kpnet_conv_size, - padding=padding, - bias=True, - ) - ) - self.bias_conv = nn.utils.weight_norm( - nn.Conv1d( - kpnet_hidden_channels, - kpnet_bias_channels, - kpnet_conv_size, - 
padding=padding, - bias=True, - ) - ) - - def forward(self, c): - """ - Args: - c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) - """ - batch, _, cond_length = c.shape - c = self.input_conv(c) - for residual_conv in self.residual_convs: - residual_conv.to(c.device) - c = c + residual_conv(c) - k = self.kernel_conv(c) - b = self.bias_conv(c) - kernels = k.contiguous().view( - batch, - self.conv_layers, - self.conv_in_channels, - self.conv_out_channels, - self.conv_kernel_size, - cond_length, - ) - bias = b.contiguous().view( - batch, - self.conv_layers, - self.conv_out_channels, - cond_length, - ) - - return kernels, bias - - def remove_weight_norm(self): - nn.utils.remove_weight_norm(self.input_conv[0]) - nn.utils.remove_weight_norm(self.kernel_conv) - nn.utils.remove_weight_norm(self.bias_conv) - for block in self.residual_convs: - nn.utils.remove_weight_norm(block[1]) - nn.utils.remove_weight_norm(block[3]) - - -class LVCBlock(torch.nn.Module): - """the location-variable convolutions""" - - def __init__( - self, - in_channels, - cond_channels, - stride, - dilations=[1, 3, 9, 27], - lReLU_slope=0.2, - conv_kernel_size=3, - cond_hop_length=256, - kpnet_hidden_channels=64, - kpnet_conv_size=3, - kpnet_dropout=0.0, - ): - super().__init__() - - self.cond_hop_length = cond_hop_length - self.conv_layers = len(dilations) - self.conv_kernel_size = conv_kernel_size - - self.kernel_predictor = KernelPredictor( - cond_channels=cond_channels, - conv_in_channels=in_channels, - conv_out_channels=2 * in_channels, - conv_layers=len(dilations), - conv_kernel_size=conv_kernel_size, - kpnet_hidden_channels=kpnet_hidden_channels, - kpnet_conv_size=kpnet_conv_size, - kpnet_dropout=kpnet_dropout, - kpnet_nonlinear_activation_params={"negative_slope": lReLU_slope}, - ) - - self.convt_pre = nn.Sequential( - nn.LeakyReLU(lReLU_slope), - nn.utils.weight_norm( - nn.ConvTranspose1d( - in_channels, - in_channels, - 2 * stride, - stride=stride, - padding=stride // 2 + stride % 2, - output_padding=stride % 2, - ) - ), - ) - - self.conv_blocks = nn.ModuleList() - for dilation in dilations: - self.conv_blocks.append( - nn.Sequential( - nn.LeakyReLU(lReLU_slope), - nn.utils.weight_norm( - nn.Conv1d( - in_channels, - in_channels, - conv_kernel_size, - padding=dilation * (conv_kernel_size - 1) // 2, - dilation=dilation, - ) - ), - nn.LeakyReLU(lReLU_slope), - ) - ) - - def forward(self, x, c): - """forward propagation of the location-variable convolutions. 
- Args: - x (Tensor): the input sequence (batch, in_channels, in_length) - c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) - - Returns: - Tensor: the output sequence (batch, in_channels, in_length) - """ - _, in_channels, _ = x.shape # (B, c_g, L') - - x = self.convt_pre(x) # (B, c_g, stride * L') - kernels, bias = self.kernel_predictor(c) - - for i, conv in enumerate(self.conv_blocks): - output = conv(x) # (B, c_g, stride * L') - - k = kernels[:, i, :, :, :, :] # (B, 2 * c_g, c_g, kernel_size, cond_length) - b = bias[:, i, :, :] # (B, 2 * c_g, cond_length) - - output = self.location_variable_convolution( - output, k, b, hop_size=self.cond_hop_length - ) # (B, 2 * c_g, stride * L'): LVC - x = x + torch.sigmoid(output[:, :in_channels, :]) * torch.tanh( - output[:, in_channels:, :] - ) # (B, c_g, stride * L'): GAU - - return x - - def location_variable_convolution(self, x, kernel, bias, dilation=1, hop_size=256): - """perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. - Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. - Args: - x (Tensor): the input sequence (batch, in_channels, in_length). - kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) - bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) - dilation (int): the dilation of convolution. - hop_size (int): the hop_size of the conditioning sequence. - Returns: - (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). - """ - batch, _, in_length = x.shape - batch, _, out_channels, kernel_size, kernel_length = kernel.shape - assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" - - padding = dilation * int((kernel_size - 1) / 2) - x = F.pad(x, (padding, padding), "constant", 0) # (batch, in_channels, in_length + 2*padding) - x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) - - if hop_size < dilation: - x = F.pad(x, (0, dilation), "constant", 0) - x = x.unfold( - 3, dilation, dilation - ) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) - x = x[:, :, :, :, :hop_size] - x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) - x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) - - o = torch.einsum("bildsk,biokl->bolsd", x, kernel) - o = o.to(memory_format=torch.channels_last_3d) - bias = bias.unsqueeze(-1).unsqueeze(-1).to(memory_format=torch.channels_last_3d) - o = o + bias - o = o.contiguous().view(batch, out_channels, -1) - - return o - - def remove_weight_norm(self): - self.kernel_predictor.remove_weight_norm() - nn.utils.remove_weight_norm(self.convt_pre[1]) - for block in self.conv_blocks: - nn.utils.remove_weight_norm(block[1]) - - -class UnivNetGenerator(nn.Module): - """ - UnivNet Generator - - Originally from https://github.com/mindslab-ai/univnet/blob/master/model/generator.py. - """ - - def __init__( - self, - noise_dim=64, - channel_size=32, - dilations=[1, 3, 9, 27], - strides=[8, 8, 4], - lReLU_slope=0.2, - kpnet_conv_size=3, - # Below are MEL configurations options that this generator requires. 
- hop_length=256, - n_mel_channels=100, - ): - super(UnivNetGenerator, self).__init__() - self.mel_channel = n_mel_channels - self.noise_dim = noise_dim - self.hop_length = hop_length - channel_size = channel_size - kpnet_conv_size = kpnet_conv_size - - self.res_stack = nn.ModuleList() - hop_length = 1 - for stride in strides: - hop_length = stride * hop_length - self.res_stack.append( - LVCBlock( - channel_size, - n_mel_channels, - stride=stride, - dilations=dilations, - lReLU_slope=lReLU_slope, - cond_hop_length=hop_length, - kpnet_conv_size=kpnet_conv_size, - ) - ) - - self.conv_pre = nn.utils.weight_norm(nn.Conv1d(noise_dim, channel_size, 7, padding=3, padding_mode="reflect")) - - self.conv_post = nn.Sequential( - nn.LeakyReLU(lReLU_slope), - nn.utils.weight_norm(nn.Conv1d(channel_size, 1, 7, padding=3, padding_mode="reflect")), - nn.Tanh(), - ) - - def forward(self, c, z): - """ - Args: - c (Tensor): the conditioning sequence of mel-spectrogram (batch, mel_channels, in_length) - z (Tensor): the noise sequence (batch, noise_dim, in_length) - - """ - z = self.conv_pre(z) # (B, c_g, L) - - for res_block in self.res_stack: - res_block.to(z.device) - z = res_block(z, c) # (B, c_g, L * s_0 * ... * s_i) - - z = self.conv_post(z) # (B, 1, L * 256) - - return z - - def eval(self, inference=False): - super(UnivNetGenerator, self).eval() - # don't remove weight norm while validation in training loop - if inference: - self.remove_weight_norm() - - def remove_weight_norm(self): - nn.utils.remove_weight_norm(self.conv_pre) - - for layer in self.conv_post: - if len(layer.state_dict()) != 0: - nn.utils.remove_weight_norm(layer) - - for res_block in self.res_stack: - res_block.remove_weight_norm() - - def inference(self, c, z=None): - # pad input mel with zeros to cut artifact - # see https://github.com/seungwonpark/melgan/issues/8 - zero = torch.full((c.shape[0], self.mel_channel, 10), -11.5129).to(c.device) - mel = torch.cat((c, zero), dim=2) - - if z is None: - z = torch.randn(c.shape[0], self.noise_dim, mel.size(2)).to(mel.device) - - audio = self.forward(mel, z) - audio = audio[:, :, : -(self.hop_length * 10)] - audio = audio.clamp(min=-1, max=1) - return audio - - -if __name__ == "__main__": - model = UnivNetGenerator() - - c = torch.randn(3, 100, 10) - z = torch.randn(3, 64, 10) - print(c.shape) - - y = model(c, z) - print(y.shape) - assert y.shape == torch.Size([3, 1, 2560]) - - pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - print(pytorch_total_params) diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index 6b8a73e859..af94675be9 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -13,7 +13,6 @@ from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer -from TTS.tts.layers.xtts.vocoder import UnivNetGenerator from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -185,7 +184,6 @@ class XttsArgs(Coqpit): clvp_checkpoint (str, optional): The checkpoint for the ConditionalLatentVariablePerseq model. Defaults to None. decoder_checkpoint (str, optional): The checkpoint for the DiffTTS model. Defaults to None. num_chars (int, optional): The maximum number of characters to generate. Defaults to 255. - use_hifigan (bool, optional): Whether to use hifigan with implicit enhancement or diffusion + univnet as a decoder. Defaults to True. 
For GPT model: gpt_max_audio_tokens (int, optional): The maximum mel tokens for the autoregressive model. Defaults to 604. @@ -223,7 +221,6 @@ class XttsArgs(Coqpit): clvp_checkpoint: str = None decoder_checkpoint: str = None num_chars: int = 255 - use_hifigan: bool = True # XTTS GPT Encoder params tokenizer_file: str = "" @@ -320,32 +317,15 @@ def init_models(self): code_stride_len=self.args.gpt_code_stride_len, ) - if self.args.use_hifigan: - self.hifigan_decoder = HifiDecoder( - input_sample_rate=self.args.input_sample_rate, - output_sample_rate=self.args.output_sample_rate, - output_hop_length=self.args.output_hop_length, - ar_mel_length_compression=self.args.gpt_code_stride_len, - decoder_input_dim=self.args.decoder_input_dim, - d_vector_dim=self.args.d_vector_dim, - cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, - ) - - if not self.args.use_hifigan: - self.diffusion_decoder = DiffusionTts( - model_channels=self.args.diff_model_channels, - num_layers=self.args.diff_num_layers, - in_channels=self.args.diff_in_channels, - out_channels=self.args.diff_out_channels, - in_latent_channels=self.args.diff_in_latent_channels, - in_tokens=self.args.diff_in_tokens, - dropout=self.args.diff_dropout, - use_fp16=self.args.diff_use_fp16, - num_heads=self.args.diff_num_heads, - layer_drop=self.args.diff_layer_drop, - unconditioned_percentage=self.args.diff_unconditioned_percentage, - ) - self.vocoder = UnivNetGenerator() + self.hifigan_decoder = HifiDecoder( + input_sample_rate=self.args.input_sample_rate, + output_sample_rate=self.args.output_sample_rate, + output_hop_length=self.args.output_hop_length, + ar_mel_length_compression=self.args.gpt_code_stride_len, + decoder_input_dim=self.args.decoder_input_dim, + d_vector_dim=self.args.d_vector_dim, + cond_d_vector_in_each_upsampling_layer=self.args.cond_d_vector_in_each_upsampling_layer, + ) @property def device(self): @@ -426,7 +406,6 @@ def get_conditioning_latents( sound_norm_refs=False, ): speaker_embedding = None - diffusion_cond_latents = None audio, sr = torchaudio.load(audio_path) audio = audio[:, : sr * max_ref_length].to(self.device) @@ -437,12 +416,9 @@ def get_conditioning_latents( if librosa_trim_db is not None: audio = librosa.effects.trim(audio, top_db=librosa_trim_db)[0] - if self.args.use_hifigan or self.args.use_hifigan: - speaker_embedding = self.get_speaker_embedding(audio, sr) - else: - diffusion_cond_latents = self.get_diffusion_cond_latents(audio, sr) + speaker_embedding = self.get_speaker_embedding(audio, sr) gpt_cond_latents = self.get_gpt_cond_latents(audio, sr, length=gpt_cond_len) # [1, 1024, T] - return gpt_cond_latents, diffusion_cond_latents, speaker_embedding + return gpt_cond_latents, speaker_embedding def synthesize(self, text, config, speaker_wav, language, **kwargs): """Synthesize speech with the given input text. @@ -575,7 +551,7 @@ def full_inference( Generated audio clip(s) as a torch tensor. Shape 1,S if k=1 else, (k,1,S) where S is the sample length. Sample rate is 24kHz. 
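(A minimal usage sketch of what this API change means for callers, for illustration only: `get_conditioning_latents` now returns two values instead of three, and `inference` no longer takes `diffusion_conditioning`. Here `model` is assumed to be an already-loaded `Xtts` instance and the reference path is hypothetical.)
```
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
    audio_path="reference.wav",  # hypothetical reference clip
    gpt_cond_len=6,
)
out = model.inference(
    "It took me quite a long time to develop a voice.",
    "en",
    gpt_cond_latent,
    speaker_embedding,  # diffusion_conditioning is no longer passed
)
wav = out["wav"]  # 24 kHz waveform, per the docstring above
```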
""" - (gpt_cond_latent, diffusion_conditioning, speaker_embedding) = self.get_conditioning_latents( + (gpt_cond_latent, speaker_embedding) = self.get_conditioning_latents( audio_path=ref_audio_path, gpt_cond_len=gpt_cond_len, max_ref_length=max_ref_len, @@ -587,7 +563,6 @@ def full_inference( language, gpt_cond_latent, speaker_embedding, - diffusion_conditioning, temperature=temperature, length_penalty=length_penalty, repetition_penalty=repetition_penalty, @@ -610,7 +585,6 @@ def inference( language, gpt_cond_latent, speaker_embedding, - diffusion_conditioning, # GPT inference temperature=0.65, length_penalty=1, @@ -639,14 +613,6 @@ def inference( text_tokens.shape[-1] < self.args.gpt_max_text_tokens ), " ❗ XTTS can only generate text with a maximum of 400 tokens." - if not self.args.use_hifigan: - diffuser = load_discrete_vocoder_diffuser( - desired_diffusion_steps=decoder_iterations, - cond_free=cond_free, - cond_free_k=cond_free_k, - sampler=decoder_sampler, - ) - with torch.no_grad(): gpt_codes = self.gpt.generate( cond_latents=gpt_cond_latent, @@ -688,11 +654,7 @@ def inference( gpt_latents = gpt_latents[:, :k] break - if decoder == "hifigan": - assert ( - hasattr(self, "hifigan_decoder") and self.hifigan_decoder is not None - ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" - wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) return { "wav": wav.cpu().numpy().squeeze(), @@ -735,9 +697,6 @@ def inference_stream( decoder="hifigan", **hf_generate_kwargs, ): - assert hasattr( - self, "hifigan_decoder" - ), "`inference_stream` requires use_hifigan to be set to true in the config.model_args, diffusion is too slow to stream." text = text.strip().lower() text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) @@ -776,13 +735,7 @@ def inference_stream( if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): gpt_latents = torch.cat(all_latents, dim=0)[None, :] - if decoder == "hifigan": - assert hasattr( - self, "hifigan_decoder" - ), "You must enable hifigan decoder to use it by setting config `use_hifigan: true`" - wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) - else: - raise NotImplementedError("Diffusion for streaming inference not implemented.") + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len ) @@ -810,10 +763,8 @@ def eval(self): # pylint: disable=redefined-builtin def get_compatible_checkpoint_state_dict(self, model_path): checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"] - ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else [] - ignore_keys += [] if self.args.use_hifigan else ["hifigan_decoder"] # remove xtts gpt trainer extra keys - ignore_keys += ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] + ignore_keys = ["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"] for key in list(checkpoint.keys()): # check if it is from the coqui Trainer if so convert it if key.startswith("xtts."): @@ -872,12 +823,7 @@ def load_checkpoint( self.load_state_dict(checkpoint, strict=strict) if eval: - if hasattr(self, "hifigan_decoder"): - self.hifigan_decoder.eval() - if hasattr(self, "diffusion_decoder"): - 
self.diffusion_decoder.eval() - if hasattr(self, "vocoder"): - self.vocoder.eval() + self.hifigan_decoder.eval() self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache, use_deepspeed=use_deepspeed) self.gpt.eval() diff --git a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py index 02e35dfd75..268a033535 100644 --- a/recipes/ljspeech/xtts_v1/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v1/train_gpt_xtts.py @@ -94,7 +94,6 @@ def main(): gpt_num_audio_tokens=8194, gpt_start_audio_token=8192, gpt_stop_audio_token=8193, - use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint ) # define audio config audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) diff --git a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py index 4d06fed168..d94204ca4d 100644 --- a/recipes/ljspeech/xtts_v2/train_gpt_xtts.py +++ b/recipes/ljspeech/xtts_v2/train_gpt_xtts.py @@ -93,7 +93,6 @@ def main(): gpt_num_audio_tokens=8194, gpt_start_audio_token=8192, gpt_stop_audio_token=8193, - use_ne_hifigan=True, # if it is true it will keep the non-enhanced keys on the output checkpoint gpt_use_masking_gt_prompt_approach=True, gpt_use_perceiver_resampler=True, ) diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 83cf537fb2..03514daa3b 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -86,7 +86,6 @@ gpt_num_audio_tokens=8194, gpt_start_audio_token=8192, gpt_stop_audio_token=8193, - use_ne_hifigan=True, ) audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) config = GPTTrainerConfig( diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index b9f6438eef..8099503855 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -86,7 +86,6 @@ gpt_stop_audio_token=8193, gpt_use_masking_gt_prompt_approach=True, gpt_use_perceiver_resampler=True, - use_ne_hifigan=True, ) audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) config = GPTTrainerConfig( From 5e72089906f5ef2c0765035511fb56ad59516109 Mon Sep 17 00:00:00 2001 From: Eren Gölge Date: Mon, 6 Nov 2023 19:51:59 +0100 Subject: [PATCH 4/4] Restore thrashed --- TTS/tts/layers/tortoise/dpm_solver.py | 1551 +++++++++++++++++++++++++ 1 file changed, 1551 insertions(+) create mode 100644 TTS/tts/layers/tortoise/dpm_solver.py diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py new file mode 100644 index 0000000000..2166eebb3c --- /dev/null +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -0,0 +1,1551 @@ +import math + +import torch + + +class NoiseScheduleVP: + def __init__( + self, + schedule="discrete", + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20.0, + dtype=torch.float32, + ): + """Create a wrapper class for the forward SDE (VP type). + + *** + Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t. + We recommend using schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + + The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
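(Illustration, not part of the restored file: for the linear VP schedule these marginals have the closed form implemented further below, from which sigma_t and the half-logSNR lambda_t follow directly.)
```
import torch

beta_0, beta_1 = 0.1, 20.0  # default continuous-time settings
t = torch.tensor([0.5])
log_alpha_t = -0.25 * t**2 * (beta_1 - beta_0) - 0.5 * t * beta_0
sigma_t = torch.sqrt(1.0 - torch.exp(2.0 * log_alpha_t))
lambda_t = log_alpha_t - torch.log(sigma_t)  # half-logSNR
```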
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + + t = self.inverse_lambda(lambda_t) + + =============================================================== + + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + + 1. For discrete-time DPMs: + + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + + Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + + + 2. For continuous-time DPMs: + + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + + =============================================================== + + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + + Example: + + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + + """ + + if schedule not in ["discrete", "linear", "cosine"]: + raise ValueError( + "Unsupported noise schedule {}. 
The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule + ) + ) + + self.schedule = schedule + if schedule == "discrete": + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1.0 + self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) + self.log_alpha_array = log_alphas.reshape( + ( + 1, + -1, + ) + ).to(dtype=dtype) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999.0 + self.cosine_t_max = ( + math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0)) + self.schedule = schedule + if schedule == "cosine": + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1.0 + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == "discrete": + return interpolate_fn( + t.reshape((-1, 1)), + self.t_array.to(t.device), + self.log_alpha_array.to(t.device), + ).reshape((-1)) + elif self.schedule == "linear": + return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == "cosine": + + def log_alpha_fn(s): + return torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0)) + + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 
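(Sanity sketch, assuming this module is importable as-is: for the linear schedule the inverse is exact up to floating point, so a round trip through marginal_lambda and inverse_lambda recovers t.)
```
import torch

ns = NoiseScheduleVP("linear", continuous_beta_0=0.1, continuous_beta_1=20.0)
t = torch.tensor([0.25, 0.5, 0.75])
assert torch.allclose(ns.inverse_lambda(ns.marginal_lambda(t)), t, atol=1e-4)
```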
+ """ + if self.schedule == "linear": + tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0**2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == "discrete": + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb) + t = interpolate_fn( + log_alpha.reshape((-1, 1)), + torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1]), + ) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb)) + + def t_fn(log_alpha_t): + return ( + torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) + * 2.0 + * (1.0 + self.cosine_s) + / math.pi + - self.cosine_s + ) + + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1.0, + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + first wrap the model function to a noise prediction model that accepts the continuous time as the input. + + We support four types of the diffusion model by setting `model_type`: + + 1. "noise": noise prediction model. (Trained by predicting noise). + + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derived in detail in Appendix D of [1], and is used in Imagen-Video [2]. + + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follow a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022).
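(Illustration with dummy tensors: the classifier-free branch implemented in `model_fn` below blends the conditional and unconditional predictions linearly.)
```
import torch

guidance_scale = 3.0
noise_uncond = torch.randn(2, 4)  # model output on unconditional_condition
noise_cond = torch.randn(2, 4)    # model output on condition
guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
```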
+ + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + + We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + + =============================================================== + + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. + "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == "discrete": + return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0 + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return (x - alpha_t * output) / sigma_t + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + return alpha_t * output + sigma_t * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + return -sigma_t * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise prediction model function that is used for DPM-Solver.
+ """ + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * sigma_t * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1.0 or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v", "score"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__( + self, + model_fn, + noise_schedule, + algorithm_type="dpmsolver++", + correcting_x0_fn=None, + correcting_xt_fn=None, + thresholding_max_val=1.0, + dynamic_thresholding_ratio=0.995, + ): + """Construct a DPM-Solver. + + We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`). + + We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you + can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the + dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space + DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space + DPMs (such as stable-diffusion). + + To support advanced algorithms in image-to-image applications, we also support corrector functions for + both x0 and xt. + + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". + correcting_x0_fn: A `str` or a function with the following format: + ``` + def correcting_x0_fn(x0, t): + x0_new = ... + return x0_new + ``` + This function is to correct the outputs of the data prediction model at each sampling step. e.g., + ``` + x0_pred = data_pred_model(xt, t) + if correcting_x0_fn is not None: + x0_pred = correcting_x0_fn(x0_pred, t) + xt_1 = update(x0_pred, xt, t) + ``` + If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. + correcting_xt_fn: A function with the following format: + ``` + def correcting_xt_fn(xt, t, step): + x_new = ... + return x_new + ``` + This function is to correct the intermediate samples xt at each sampling step. e.g., + ``` + xt = ... + xt = correcting_xt_fn(xt, t, step) + ``` + thresholding_max_val: A `float`. The max value for thresholding. + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. + dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). + Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 
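(Standalone sketch of the dynamic thresholding recipe implemented in `dynamic_thresholding_fn` below; the tensors are dummies and the constants match the defaults above.)
```
import torch

x0 = torch.randn(2, 3, 8, 8) * 2.0  # dummy x0 prediction
p, max_val = 0.995, 1.0             # ratio and max value, as above
s = torch.quantile(x0.abs().reshape(x0.shape[0], -1), p, dim=1)
s = torch.maximum(s, max_val * torch.ones_like(s)).view(-1, 1, 1, 1)
x0 = torch.clamp(x0, -s, s) / s     # rescaled into [-1, 1]
```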
+ + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, + Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models + with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) + self.noise_schedule = noise_schedule + assert algorithm_type in ["dpmsolver", "dpmsolver++"] + self.algorithm_type = algorithm_type + if correcting_x0_fn == "dynamic_thresholding": + self.correcting_x0_fn = self.dynamic_thresholding_fn + else: + self.correcting_x0_fn = correcting_x0_fn + self.correcting_xt_fn = correcting_xt_fn + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.thresholding_max_val = thresholding_max_val + + def dynamic_thresholding_fn(self, x0, t): + """ + The dynamic thresholding method. + """ + dims = x0.dim() + p = self.dynamic_thresholding_ratio + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims( + torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), + dims, + ) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with corrector). + """ + noise = self.noise_prediction_fn(x, t) + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - sigma_t * noise) / alpha_t + if self.correcting_x0_fn is not None: + x0 = self.correcting_x0_fn(x0, t) + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. + """ + if self.algorithm_type == "dpmsolver++": + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). 
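(Illustration: the three spacings differ only in how [t_0, t_T] is discretized; the two time-based variants reduce to the following, here with N=10.)
```
import torch

t_T, t_0, N = 1.0, 1e-3, 10
uniform = torch.linspace(t_T, t_0, N + 1)                   # 'time_uniform'
quadratic = torch.linspace(t_T**0.5, t_0**0.5, N + 1) ** 2  # 'time_quadratic'
```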
+ """ + if skip_type == "logSNR": + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == "time_uniform": + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == "time_quadratic": + t_order = 2 + t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type) + ) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
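(Quick check of the splitting rule above: with steps=20 and order=3, K = 20 // 3 + 1 = 7 and 20 % 3 == 2, so the schedule is six third-order steps plus one second-order step.)
```
steps = 20
K = steps // 3 + 1            # 7
orders = [3] * (K - 1) + [2]  # [3, 3, 3, 3, 3, 3, 2]
assert sum(orders) == steps   # every function evaluation is used
```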
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3,] * ( + K - 2 + ) + [2, 1] + elif steps % 3 == 1: + orders = [3,] * ( + K - 1 + ) + [1] + else: + orders = [3,] * ( + K - 1 + ) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [ + 2, + ] * K + else: + K = steps // 2 + 1 + orders = [2,] * ( + K - 1 + ) + [1] + elif order == 1: + K = 1 + orders = [ + 1, + ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == "logSNR": + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum( + torch.tensor( + [ + 0, + ] + + orders + ), + 0, + ).to(device) + ] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s + if return_intermediate: + return x_t, {"model_s": model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update( + self, + x, + s, + t, + r1=0.5, + model_s=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
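(In LaTeX notation, the two branches of `dpm_solver_first_update` above implement, with h = \lambda_t - \lambda_s:
    x_t = \frac{\sigma_t}{\sigma_s} x_s - \alpha_t \left(e^{-h} - 1\right) \hat{x}_\theta(x_s, s)   [dpmsolver++, data prediction]
    x_t = \frac{\alpha_t}{\alpha_s} x_s - \sigma_t \left(e^{h} - 1\right) \hat{\epsilon}_\theta(x_s, s)   [dpmsolver, noise prediction]
which is DDIM expressed in the (alpha, sigma, lambda) parameterization, as the docstring notes.)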
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(t), + ) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + if solver_type == "dpmsolver": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) + ) + elif solver_type == "taylor": + x_t = ( + torch.exp(log_alpha_t - log_alpha_s) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update( + self, + x, + s, + t, + r1=1.0 / 3.0, + r2=2.0 / 3.0, + model_s=None, + model_s1=None, + return_intermediate=False, + solver_type="dpmsolver", + ): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
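(The second-order singlestep update above adds one intermediate model evaluation at \lambda_{s_1} = \lambda_s + r_1 h, the logSNR midpoint for the default r_1 = 0.5; in the dpmsolver++ branch with solver_type='dpmsolver' the extra evaluation enters as a first-order correction:
    x_t = \frac{\sigma_t}{\sigma_s} x_s - \alpha_t \left(e^{-h} - 1\right) \hat{x}_s - \frac{1}{2 r_1} \alpha_t \left(e^{-h} - 1\right) \left(\hat{x}_{s_1} - \hat{x}_s\right)
with \hat{x}_s, \hat{x}_{s_1} the model values at s and s_1.)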
+ """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1.0 / 3.0 + if r2 is None: + r2 = 2.0 / 3.0 + ns = self.noise_schedule + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ( + ns.marginal_log_mean_coeff(s), + ns.marginal_log_mean_coeff(s1), + ns.marginal_log_mean_coeff(s2), + ns.marginal_log_mean_coeff(t), + ) + sigma_s, sigma_s1, sigma_s2, sigma_t = ( + ns.marginal_std(s), + ns.marginal_std(s1), + ns.marginal_std(s2), + ns.marginal_std(t), + ) + alpha_s1, alpha_s2, alpha_t = ( + torch.exp(log_alpha_s1), + torch.exp(log_alpha_s2), + torch.exp(log_alpha_t), + ) + + if self.algorithm_type == "dpmsolver++": + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0 + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (sigma_s2 / sigma_s) * x + - (alpha_s2 * phi_12) * model_s + + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (sigma_t / sigma_s) * x + - (alpha_t * phi_1) * model_s + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0 + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + (torch.exp(log_alpha_s2 - log_alpha_s)) * x + - (sigma_s2 * phi_12) * model_s + - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s) + ) + elif solver_type == "taylor": + D1_0 = (1.0 / r1) * (model_s1 - model_s) + D1_1 = (1.0 / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + (torch.exp(log_alpha_t - log_alpha_s)) * x + - (sigma_t * phi_1) * model_s + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + + if return_intermediate: + return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. 
The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ["dpmsolver", "taylor"]: + raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] + t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] + lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + if solver_type == "dpmsolver": + x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0 + elif solver_type == "taylor": + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * (phi_1 / h + 1.0)) * D1_0 + ) + else: + phi_1 = torch.expm1(h) + if solver_type == "dpmsolver": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - 0.5 * (sigma_t * phi_1) * D1_0 + ) + elif solver_type == "taylor": + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * (phi_1 / h - 1.0)) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. 
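(Unlike the singlestep variants, the multistep updates reuse cached model values from earlier steps, so only one new model call is needed per step; the correction terms are backward differences in logSNR. For the second-order update above, with h_0 = \lambda_{prev,0} - \lambda_{prev,1}, h = \lambda_t - \lambda_{prev,0} and r_0 = h_0 / h:
    D^1_0 = \frac{1}{r_0} \left(\hat{x}_{prev,0} - \hat{x}_{prev,1}\right)
as implemented in the body above.)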
+ """ + ns = self.noise_schedule + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ( + ns.marginal_lambda(t_prev_2), + ns.marginal_lambda(t_prev_1), + ns.marginal_lambda(t_prev_0), + ns.marginal_lambda(t), + ) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1) + D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.algorithm_type == "dpmsolver++": + phi_1 = torch.expm1(-h) + phi_2 = phi_1 / h + 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (sigma_t / sigma_prev_0) * x + - (alpha_t * phi_1) * model_prev_0 + + (alpha_t * phi_2) * D1 + - (alpha_t * phi_3) * D2 + ) + else: + phi_1 = torch.expm1(h) + phi_2 = phi_1 / h - 1.0 + phi_3 = phi_2 / h - 0.5 + x_t = ( + (torch.exp(log_alpha_t - log_alpha_prev_0)) * x + - (sigma_t * phi_1) * model_prev_0 + - (sigma_t * phi_2) * D1 + - (sigma_t * phi_3) * D2 + ) + return x_t + + def singlestep_dpm_solver_update( + self, + x, + s, + t, + order, + return_intermediate=False, + solver_type="dpmsolver", + r1=None, + r2=None, + ): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (1,). + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpmsolver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + ) + elif order == 3: + return self.singlestep_dpm_solver_third_update( + x, + s, + t, + return_intermediate=return_intermediate, + solver_type=solver_type, + r1=r1, + r2=r2, + ) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) + t: A pytorch tensor. The ending time, with the shape (1,). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpmsolver' or 'taylor'. 
+    def singlestep_dpm_solver_update(
+        self,
+        x,
+        s,
+        t,
+        order,
+        return_intermediate=False,
+        solver_type="dpmsolver",
+        r1=None,
+        r2=None,
+    ):
+        """
+        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+
+        Args:
+            x: A pytorch tensor. The initial value at time `s`.
+            s: A pytorch tensor. The starting time, with shape (1,).
+            t: A pytorch tensor. The ending time, with shape (1,).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
+            r1: A `float`. The hyperparameter of the second-order or third-order solver.
+            r2: A `float`. The hyperparameter of the third-order solver.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+        elif order == 2:
+            return self.singlestep_dpm_solver_second_update(
+                x,
+                s,
+                t,
+                return_intermediate=return_intermediate,
+                solver_type=solver_type,
+                r1=r1,
+            )
+        elif order == 3:
+            return self.singlestep_dpm_solver_third_update(
+                x,
+                s,
+                t,
+                return_intermediate=return_intermediate,
+                solver_type=solver_type,
+                r1=r1,
+                r2=r2,
+            )
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
+        """
+        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+
+        Args:
+            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+            model_prev_list: A list of pytorch tensors. The previous computed model values.
+            t_prev_list: A list of pytorch tensors. The previous times, each with shape (1,).
+            t: A pytorch tensor. The ending time, with shape (1,).
+            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
+                The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
+        Returns:
+            x_t: A pytorch tensor. The approximated solution at time `t`.
+        """
+        if order == 1:
+            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+        elif order == 2:
+            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        elif order == 3:
+            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+        else:
+            raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
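+
+    # Descriptive note: the two dispatchers above are the entry points used by
+    # the sampling loops; `sample()` below maintains `model_prev_list` and
+    # `t_prev_list` as the `order` most recent model outputs and times, and
+    # calls multistep_dpm_solver_update once per step.
+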
+ """ + ns = self.noise_schedule + s = t_T * torch.ones((1,)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + + def lower_update(x, s, t): + return self.dpm_solver_first_update(x, s, t, return_intermediate=True) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) + + elif order == 3: + r1, r2 = 1.0 / 3.0, 2.0 / 3.0 + + def lower_update(x, s, t): + return self.singlestep_dpm_solver_second_update( + x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type + ) + + def higher_update(x, s, t, **kwargs): + return self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) + + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max( + torch.ones_like(x).to(x) * atol, + rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)), + ) + + def norm_fn(v): + return torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.0): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min( + theta * h * torch.float_power(E, -1.0 / order).float(), + lambda_0 - lambda_s, + ) + nfe += order + print("adaptive solver nfe", nfe) + return x + + def add_noise(self, x, t, noise=None): + """ + Compute the noised input xt = alpha_t * x + sigma_t * noise. + + Args: + x: A `torch.Tensor` with shape `(batch_size, *shape)`. + t: A `torch.Tensor` with shape `(t_size,)`. + Returns: + xt with shape `(t_size, batch_size, *shape)`. + """ + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + if noise is None: + noise = torch.randn((t.shape[0], *x.shape), device=x.device) + x = x.reshape((-1, *x.shape)) + xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise + if t.shape[0] == 1: + return xt.squeeze(0) + else: + return xt + + def inverse( + self, + x, + steps=20, + t_start=None, + t_end=None, + order=2, + skip_type="time_uniform", + method="multistep", + lower_order_final=True, + denoise_to_zero=False, + solver_type="dpmsolver", + atol=0.0078, + rtol=0.05, + return_intermediate=False, + ): + """ + Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. + For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. + """ + t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start + t_T = self.noise_schedule.T if t_end is None else t_end + assert ( + t_0 > 0 and t_T > 0 + ), "Time range needs to be greater than 0. 
+    def add_noise(self, x, t, noise=None):
+        """
+        Compute the noised input xt = alpha_t * x + sigma_t * noise.
+
+        Args:
+            x: A `torch.Tensor` with shape `(batch_size, *shape)`.
+            t: A `torch.Tensor` with shape `(t_size,)`.
+            noise: A `torch.Tensor` with shape `(t_size, batch_size, *shape)`. If None, standard Gaussian noise is sampled.
+        Returns:
+            xt with shape `(t_size, batch_size, *shape)`.
+        """
+        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+        if noise is None:
+            noise = torch.randn((t.shape[0], *x.shape), device=x.device)
+        x = x.reshape((-1, *x.shape))
+        xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
+        if t.shape[0] == 1:
+            return xt.squeeze(0)
+        else:
+            return xt
+
+    def inverse(
+        self,
+        x,
+        steps=20,
+        t_start=None,
+        t_end=None,
+        order=2,
+        skip_type="time_uniform",
+        method="multistep",
+        lower_order_final=True,
+        denoise_to_zero=False,
+        solver_type="dpmsolver",
+        atol=0.0078,
+        rtol=0.05,
+        return_intermediate=False,
+    ):
+        """
+        Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
+        For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps used during training.
+        """
+        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
+        t_T = self.noise_schedule.T if t_end is None else t_end
+        assert (
+            t_0 > 0 and t_T > 0
+        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"
+        return self.sample(
+            x,
+            steps=steps,
+            t_start=t_0,
+            t_end=t_T,
+            order=order,
+            skip_type=skip_type,
+            method=method,
+            lower_order_final=lower_order_final,
+            denoise_to_zero=denoise_to_zero,
+            solver_type=solver_type,
+            atol=atol,
+            rtol=rtol,
+            return_intermediate=return_intermediate,
+        )
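+
+    # Illustrative sketch (hypothetical tensors; `solver` is a DPM_Solver
+    # instance): add_noise diffuses a clean batch to time t in one call:
+    #
+    #     x0 = torch.randn(4, 3, 32, 32)   # clean samples, (batch_size, *shape)
+    #     t = torch.tensor([0.5])          # a single diffusion time, (t_size,)
+    #     xt = solver.add_noise(x0, t)     # alpha_t * x0 + sigma_t * noise
+    #
+    # Because t_size == 1 here, the leading time axis is squeezed away and `xt`
+    # has the same shape as `x0`.
+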
+    def sample(
+        self,
+        x,
+        steps=20,
+        t_start=None,
+        t_end=None,
+        order=2,
+        skip_type="time_uniform",
+        method="multistep",
+        lower_order_final=True,
+        denoise_to_zero=False,
+        solver_type="dpmsolver",
+        atol=0.0078,
+        rtol=0.05,
+        return_intermediate=False,
+    ):
+        """
+        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+
+        =====================================================
+
+        We support the following algorithms for both noise prediction models and data prediction models:
+        - 'singlestep':
+            Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+            We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+            The total number of function evaluations (NFE) == `steps`.
+            Given a fixed NFE == `steps`, the sampling procedure is:
+            - If `order` == 1:
+                - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+            - If `order` == 2:
+                - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+                - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+                - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+            - If `order` == 3:
+                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+                - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+                - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+        - 'multistep':
+            Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+            We initialize the first `order` values by lower-order multistep solvers.
+            Given a fixed NFE == `steps`, the sampling procedure is:
+            Denote K = steps.
+            - If `order` == 1:
+                - We use K steps of DPM-Solver-1 (i.e. DDIM).
+            - If `order` == 2:
+                - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+            - If `order` == 3:
+                - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+        - 'singlestep_fixed':
+            Fixed-order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+            We use singlestep DPM-Solver-`order` for `order` = 1 or 2 or 3, with a total of [`steps` // `order`] * `order` NFE.
+        - 'adaptive':
+            Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+            We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+            You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+            (NFE) and the sample quality.
+            - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+            - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+
+        =====================================================
+
+        Some advice for choosing the algorithm:
+        - For **unconditional sampling** or **guided sampling with a small guidance scale** by DPMs:
+            Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
+            e.g., DPM-Solver:
+                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
+                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                        skip_type='time_uniform', method='singlestep')
+            e.g., DPM-Solver++:
+                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
+                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+                        skip_type='time_uniform', method='singlestep')
+        - For **guided sampling with a large guidance scale** by DPMs:
+            Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
+            e.g.
+                >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
+                >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+                        skip_type='time_uniform', method='multistep')
+
+        We support three types of `skip_type`:
+        - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+        - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+        - 'time_quadratic': quadratic time for the time steps.
+
+        =====================================================
+        Args:
+            x: A pytorch tensor. The initial value at time `t_start`,
+                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+            steps: An `int`. The total number of function evaluations (NFE).
+            t_start: A `float`. The starting time of the sampling.
+                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+            t_end: A `float`. The ending time of the sampling.
+                If `t_end` is None, we use 1. / self.noise_schedule.total_N,
+                e.g. if total_N == 1000, we have `t_end` == 1e-3.
+                For discrete-time DPMs:
+                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+                For continuous-time DPMs:
+                    - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+            order: An `int`. The order of DPM-Solver.
+            skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+            method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+
+                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when sampling
+                from diffusion SDEs for low-resolution images (such as CIFAR-10). However, we
+                observed that it does not matter for high-resolution images. As it needs an
+                additional NFE, we do not recommend it for high-resolution images.
+            lower_order_final: A `bool`. Whether to use lower-order solvers at the final steps.
+                Only valid for `method=multistep` and `steps < 15`. We empirically find that
+                this trick is key to stabilizing sampling by DPM-Solver with very few steps
+                (especially for steps <= 10). So we recommend setting it to `True`.
+            solver_type: A `str`. The taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
+            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+            return_intermediate: A `bool`. Whether to save the xt at each step.
+                When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
+        Returns:
+            x_end: A pytorch tensor. The approximated solution at time `t_end`.
+
+        """
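+        # Overview of the body below (descriptive comment): resolve the time
+        # interval [t_0, t_T], then dispatch on `method`: 'adaptive' delegates to
+        # dpm_solver_adaptive; 'multistep' warms up with lower-order updates
+        # before running order-`order` steps; 'singlestep'/'singlestep_fixed'
+        # split the interval into blocks and solve each block with one
+        # singlestep update.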
+        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
+        t_T = self.noise_schedule.T if t_start is None else t_start
+        assert (
+            t_0 > 0 and t_T > 0
+        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"
+        if return_intermediate:
+            assert method in [
+                "multistep",
+                "singlestep",
+                "singlestep_fixed",
+            ], "Cannot use adaptive solver when saving intermediate values"
+        if self.correcting_xt_fn is not None:
+            assert method in [
+                "multistep",
+                "singlestep",
+                "singlestep_fixed",
+            ], "Cannot use adaptive solver when correcting_xt_fn is not None"
+        device = x.device
+        intermediates = []
+        with torch.no_grad():
+            if method == "adaptive":
+                x = self.dpm_solver_adaptive(
+                    x,
+                    order=order,
+                    t_T=t_T,
+                    t_0=t_0,
+                    atol=atol,
+                    rtol=rtol,
+                    solver_type=solver_type,
+                )
+            elif method == "multistep":
+                assert steps >= order
+                timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+                assert timesteps.shape[0] - 1 == steps
+                # Initialize the first model value at the starting time.
+                step = 0
+                t = timesteps[step]
+                t_prev_list = [t]
+                model_prev_list = [self.model_fn(x, t)]
+                if self.correcting_xt_fn is not None:
+                    x = self.correcting_xt_fn(x, t, step)
+                if return_intermediate:
+                    intermediates.append(x)
+                # Init the first `order` values by lower-order multistep DPM-Solver.
+                for step in range(1, order):
+                    t = timesteps[step]
+                    x = self.multistep_dpm_solver_update(
+                        x,
+                        model_prev_list,
+                        t_prev_list,
+                        t,
+                        step,
+                        solver_type=solver_type,
+                    )
+                    if self.correcting_xt_fn is not None:
+                        x = self.correcting_xt_fn(x, t, step)
+                    if return_intermediate:
+                        intermediates.append(x)
+                    t_prev_list.append(t)
+                    model_prev_list.append(self.model_fn(x, t))
+                # Compute the remaining values by `order`-th order multistep DPM-Solver.
+                for step in range(order, steps + 1):
+                    t = timesteps[step]
+                    # We only use lower order for steps < 10
+                    if lower_order_final and steps < 10:
+                        step_order = min(order, steps + 1 - step)
+                    else:
+                        step_order = order
+                    x = self.multistep_dpm_solver_update(
+                        x,
+                        model_prev_list,
+                        t_prev_list,
+                        t,
+                        step_order,
+                        solver_type=solver_type,
+                    )
+                    if self.correcting_xt_fn is not None:
+                        x = self.correcting_xt_fn(x, t, step)
+                    if return_intermediate:
+                        intermediates.append(x)
+                    for i in range(order - 1):
+                        t_prev_list[i] = t_prev_list[i + 1]
+                        model_prev_list[i] = model_prev_list[i + 1]
+                    t_prev_list[-1] = t
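+                    # Descriptive note: t_prev_list/model_prev_list form a
+                    # sliding window over the `order` most recent times and
+                    # model outputs; the loop above shifts the window and the
+                    # newest model value is filled in just below (except after
+                    # the very last step).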
+                    # We do not need to evaluate the final model value.
+                    if step < steps:
+                        model_prev_list[-1] = self.model_fn(x, t)
+            elif method in ["singlestep", "singlestep_fixed"]:
+                if method == "singlestep":
+                    (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver(
+                        steps=steps,
+                        order=order,
+                        skip_type=skip_type,
+                        t_T=t_T,
+                        t_0=t_0,
+                        device=device,
+                    )
+                elif method == "singlestep_fixed":
+                    K = steps // order
+                    orders = [
+                        order,
+                    ] * K
+                    timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+                for step, order in enumerate(orders):
+                    s, t = timesteps_outer[step], timesteps_outer[step + 1]
+                    timesteps_inner = self.get_time_steps(
+                        skip_type=skip_type,
+                        t_T=s.item(),
+                        t_0=t.item(),
+                        N=order,
+                        device=device,
+                    )
+                    lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+                    h = lambda_inner[-1] - lambda_inner[0]
+                    r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+                    r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+                    x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
+                    if self.correcting_xt_fn is not None:
+                        x = self.correcting_xt_fn(x, t, step)
+                    if return_intermediate:
+                        intermediates.append(x)
+            else:
+                raise ValueError("Got wrong method {}".format(method))
+            if denoise_to_zero:
+                t = torch.ones((1,)).to(device) * t_0
+                x = self.denoise_to_zero_fn(x, t)
+                if self.correcting_xt_fn is not None:
+                    x = self.correcting_xt_fn(x, t, step + 1)
+                if return_intermediate:
+                    intermediates.append(x)
+        if return_intermediate:
+            return x, intermediates
+        else:
+            return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+
+def interpolate_fn(x, xp, yp):
+    """
+    A piecewise linear function y = f(x), using xp and yp as keypoints.
+    We implement f(x) in a differentiable way (i.e. applicable for autograd).
+    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+
+    Args:
+        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+        yp: PyTorch tensor with shape [C, K].
+    Returns:
+        The function values f(x), with shape [N, C].
+    """
+    N, K = x.shape[0], xp.shape[1]
+    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+    x_idx = torch.argmin(x_indices, dim=2)
+    cand_start_idx = x_idx - 1
+    start_idx = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(1, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K),
+            torch.tensor(K - 2, device=x.device),
+            cand_start_idx,
+        ),
+    )
+    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+    start_idx2 = torch.where(
+        torch.eq(x_idx, 0),
+        torch.tensor(0, device=x.device),
+        torch.where(
+            torch.eq(x_idx, K),
+            torch.tensor(K - 2, device=x.device),
+            cand_start_idx,
+        ),
+    )
+    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+    return cand
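+
+# Illustrative sanity check (not part of the original file), using the shapes
+# documented above with N = 1 query, C = 1 channel, K = 3 keypoints:
+#
+#     x = torch.tensor([[1.5]])                 # queries, shape [N, C]
+#     xp = torch.tensor([[0.0, 1.0, 2.0]])      # keypoints, shape [C, K]
+#     yp = torch.tensor([[0.0, 10.0, 20.0]])    # values, shape [C, K]
+#     interpolate_fn(x, xp, yp)                 # -> tensor([[15.]])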
+ """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), + torch.tensor(K - 2, device=x.device), + cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. + """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file