Spaces:
Runtime error
Runtime error
| # Copyright 2022 Stanford University Team and The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion | |
| # and https://github.com/hojonathanho/diffusion | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple, Union | |
| import flax | |
| import jax.numpy as jnp | |
| from ..configuration_utils import ConfigMixin, register_to_config | |
| from ..utils import deprecate | |
| from .scheduling_utils_flax import ( | |
| _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, | |
| FlaxSchedulerMixin, | |
| FlaxSchedulerOutput, | |
| broadcast_to_shape_from_left, | |
| ) | |
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
    """
    Discretize the squared-cosine `alpha_bar` curve into a sequence of per-step betas.

    `alpha_bar(t)` is the cumulative product of (1 - beta) at normalized time `t` in [0, 1]. Step `i` of `n` gets
    `beta_i = 1 - alpha_bar((i + 1) / n) / alpha_bar(i / n)`, clipped from above at `max_beta`.

    Args:
        num_diffusion_timesteps (`int`): how many betas to produce.
        max_beta (`float`): upper bound applied to every beta; keep it below 1 to avoid singularities.

    Returns:
        betas (`jnp.ndarray`): float32 array of length `num_diffusion_timesteps`.
    """

    def cosine_alpha_bar(t):
        # Squared cosine with a small 0.008 offset.
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    n = num_diffusion_timesteps
    schedule = [
        min(1 - cosine_alpha_bar((step + 1) / n) / cosine_alpha_bar(step / n), max_beta)
        for step in range(n)
    ]
    return jnp.array(schedule, dtype=jnp.float32)
@flax.struct.dataclass
class DDIMSchedulerState:
    """
    Container for the DDIM scheduler's settable values.

    `flax.struct.dataclass` makes instances frozen and provides `replace(...)`. That is how
    `FlaxDDIMScheduler.set_timesteps` produces an updated state instead of mutating this one.
    """

    # setable values
    timesteps: jnp.ndarray  # timesteps to iterate over, in descending order
    alphas_cumprod: jnp.ndarray  # cumulative product of (1 - beta) for every training timestep
    num_inference_steps: Optional[int] = None  # None until `set_timesteps` has been called

    @classmethod
    def create(cls, num_train_timesteps: int, alphas_cumprod: jnp.ndarray):
        """
        Build the initial state.

        The state covers every training timestep in descending order, plus the given cumulative alphas.

        Args:
            num_train_timesteps (`int`): number of training timesteps; timesteps become `[n - 1, ..., 0]`.
            alphas_cumprod (`jnp.ndarray`): cumulative product of (1 - beta) over the training schedule.
        """
        return cls(timesteps=jnp.arange(0, num_train_timesteps)[::-1], alphas_cumprod=alphas_cumprod)
@dataclass
class FlaxDDIMSchedulerOutput(FlaxSchedulerOutput):
    """
    Output of `FlaxDDIMScheduler.step`.

    It holds the inherited `prev_sample` field plus the scheduler state to carry into the next step.
    """

    # @dataclass is required so that `state` becomes a real constructor field alongside `prev_sample`.
    state: DDIMSchedulerState
class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
    """
    Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
    diffusion probabilistic models (DDPMs) with non-Markovian guidance.

    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
    [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
    [`~SchedulerMixin.from_pretrained`] functions.

    For more details, see the original paper: https://arxiv.org/abs/2010.02502

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        beta_start (`float`): the starting `beta` value of inference.
        beta_end (`float`): the final `beta` value.
        beta_schedule (`str`):
            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        set_alpha_to_one (`bool`, default `True`):
            each diffusion step uses the value of alphas product at that step and at the previous one. For the final
            step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the value of alpha at step 0.
        steps_offset (`int`, default `0`):
            an offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
            stable diffusion.
        prediction_type (`str`, default `epsilon`):
            indicates whether the model predicts the noise (`epsilon`), the denoised sample (`sample`), or the
            velocity (`v_prediction`).

    `trained_betas` and `clip_sample` are not parameters of this implementation. `step` always uses `eta = 0`, so no
    random noise is added when stepping.
    """

    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
    _deprecated_kwargs = ["predict_epsilon"]

    @property
    def has_state(self):
        """Return True: this scheduler keeps its settable values in a separate `DDIMSchedulerState` (see `create_state`)."""
        return True

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        beta_schedule: str = "linear",
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        **kwargs,
    ):
        # `predict_epsilon` is the deprecated spelling of `prediction_type`; translate it into the config value.
        message = (
            "Please make sure to instantiate your scheduler with `prediction_type` instead. E.g. `scheduler ="
            " FlaxDDIMScheduler.from_pretrained(<model_id>, prediction_type='epsilon')`."
        )
        predict_epsilon = deprecate("predict_epsilon", "0.11.0", message, take_from=kwargs)
        if predict_epsilon is not None:
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

        if beta_schedule == "linear":
            self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = jnp.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=jnp.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")

        self.alphas = 1.0 - self.betas

        # The instance copy seeds `DDIMSchedulerState` (see `create_state`) and is what `add_noise` reads.
        self._alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = jnp.array(1.0) if set_alpha_to_one else float(self._alphas_cumprod[0])

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

    def scale_model_input(
        self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
    ) -> jnp.ndarray:
        """
        Return `sample` unchanged; DDIM does not rescale the model input.

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            sample (`jnp.ndarray`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `jnp.ndarray`: scaled input sample
        """
        return sample

    def create_state(self):
        """Build the initial `DDIMSchedulerState` from the configured number of training timesteps."""
        return DDIMSchedulerState.create(
            num_train_timesteps=self.config.num_train_timesteps, alphas_cumprod=self._alphas_cumprod
        )

    def _get_variance(self, timestep, prev_timestep, alphas_cumprod):
        """
        Return the DDIM variance sigma_t^2 / eta^2 from formula (16) of the DDIM paper.

        When `prev_timestep` is negative, `final_alpha_cumprod` stands in for the previous cumulative alpha.
        """
        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(
        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
    ) -> DDIMSchedulerState:
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Args:
            state (`DDIMSchedulerState`):
                the `FlaxDDIMScheduler` state data class instance.
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
            shape (`Tuple`, optional):
                accepted for interface compatibility; not used by this scheduler.

        Returns:
            `DDIMSchedulerState`: a new state with `timesteps` (descending, shifted by `steps_offset`) and
            `num_inference_steps` set.
        """
        offset = self.config.steps_offset

        step_ratio = self.config.num_train_timesteps // num_inference_steps
        # Create evenly spaced integer timesteps by multiplying by the ratio.
        # Round to avoid issues when num_inference_steps is a power of 3.
        timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1]
        timesteps = timesteps + offset

        return state.replace(num_inference_steps=num_inference_steps, timesteps=timesteps)

    def step(
        self,
        state: DDIMSchedulerState,
        model_output: jnp.ndarray,
        timestep: int,
        sample: jnp.ndarray,
        return_dict: bool = True,
    ) -> Union[FlaxDDIMSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            state (`DDIMSchedulerState`): the `FlaxDDIMScheduler` state data class instance.
            model_output (`jnp.ndarray`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`jnp.ndarray`):
                current instance of sample being created by diffusion process.
            return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

        Returns:
            [`FlaxDDIMSchedulerOutput`] or `tuple`: [`FlaxDDIMSchedulerOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is the sample tensor and the second is the state.

        Raises:
            ValueError: if `set_timesteps` has not been applied to `state`, or `prediction_type` is unknown.
        """
        if state.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Ideally, read DDIM paper in-detail understanding

        # Notation (<variable name> -> <name in paper>
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # TODO(Patrick) - eta is always 0.0 for now, allow to be set in step function
        eta = 0.0

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps

        alphas_cumprod = state.alphas_cumprod

        # 2. compute alphas, betas
        # A negative prev_timestep wraps around when indexing; jnp.where then selects final_alpha_cumprod instead.
        alpha_prod_t = alphas_cumprod[timestep]
        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, alphas_cumprod[prev_timestep], self.final_alpha_cumprod)

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.config.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # Convert the v-prediction into the equivalent epsilon (noise) prediction for step 5.
            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep, alphas_cumprod)
        std_dev_t = eta * variance ** (0.5)

        # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if not return_dict:
            return (prev_sample, state)

        return FlaxDDIMSchedulerOutput(prev_sample=prev_sample, state=state)

    def add_noise(
        self,
        original_samples: jnp.ndarray,
        noise: jnp.ndarray,
        timesteps: jnp.ndarray,
    ) -> jnp.ndarray:
        """
        Diffuse `original_samples` forward to `timesteps`.

        The result is `sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise`.

        Args:
            original_samples (`jnp.ndarray`): clean samples x_0.
            noise (`jnp.ndarray`): noise to mix in; presumably the same shape as `original_samples` — confirm.
            timesteps (`jnp.ndarray`): training timesteps at which to evaluate the schedule. The per-timestep
                coefficients are flattened and broadcast against `original_samples.shape` via
                `broadcast_to_shape_from_left`.

        Returns:
            `jnp.ndarray`: the noised samples.
        """
        # `__init__` stores the schedule as `_alphas_cumprod`; there is no `alphas_cumprod` attribute on the scheduler.
        alpha_prod = self._alphas_cumprod[timesteps]

        sqrt_alpha_prod = (alpha_prod**0.5).flatten()
        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

        # sqrt(1 - alpha_bar_t): the scale applied to the noise term.
        sqrt_one_minus_alpha_prod = ((1 - alpha_prod) ** 0.5).flatten()
        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps configured for this scheduler."""
        return self.config.num_train_timesteps