Compare commits


2 Commits

Author      SHA1        Message                                                   Date
Sayak Paul  835a087a47  Merge branch 'main' into apply-lora-scale-decorator      2026-01-20 10:44:21 +05:30
sayakpaul   afa4a23c6c  feat: implement apply_lora_scale to remove boilerplate.  2026-01-19 10:04:24 +05:30
31 changed files with 172 additions and 203 deletions

View File

@@ -478,7 +478,7 @@ class PeftAdapterMixin:
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
weights (`Union[List[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -495,7 +495,7 @@ class PeftAdapterMixin:
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:

View File

@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -634,6 +634,7 @@ class FluxTransformer2DModel(
self.gradient_checkpointing = False
@apply_lora_scale("joint_attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
@@ -675,20 +676,6 @@ class FluxTransformer2DModel(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
logger.warning(
"Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
)
hidden_states = self.x_embedder(hidden_states)
@@ -785,10 +772,6 @@ class FluxTransformer2DModel(
hidden_states = self.norm_out(hidden_states, temb)
output = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (output,)

View File

@@ -1552,11 +1552,11 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
self._blocks = blocks
self.blocks = blocks
self._components_manager = components_manager
self._collection = collection
self._component_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self._blocks.expected_configs}
self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
# update component_specs and config_specs based on modular_model_index.json
if modular_config_dict is not None:
@@ -1603,9 +1603,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for name, config_spec in self._config_specs.items():
default_configs[name] = config_spec.default
self.register_to_config(**default_configs)
self.register_to_config(
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None)
@property
def default_call_parameters(self) -> Dict[str, Any]:
@@ -1614,7 +1612,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- Dictionary mapping input names to their default values
"""
params = {}
for input_param in self._blocks.inputs:
for input_param in self.blocks.inputs:
params[input_param.name] = input_param.default
return params
@@ -1777,15 +1775,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Returns:
- The docstring of the pipeline blocks
"""
return self._blocks.doc
@property
def blocks(self) -> ModularPipelineBlocks:
"""
Returns:
- A copy of the pipeline blocks
"""
return deepcopy(self._blocks)
return self.blocks.doc
def register_components(self, **kwargs):
"""
@@ -2519,7 +2509,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
def set_progress_bar_config(self, **kwargs):
for sub_block_name, sub_block in self._blocks.sub_blocks.items():
for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs)
@@ -2573,7 +2563,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Add inputs to state, using defaults if not provided in the kwargs or the state
# if same input already in the state, will override it if provided in the kwargs
for expected_input_param in self._blocks.inputs:
for expected_input_param in self.blocks.inputs:
name = expected_input_param.name
default = expected_input_param.default
kwargs_type = expected_input_param.kwargs_type
@@ -2592,9 +2582,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
# Run the pipeline
with torch.no_grad():
try:
_, state = self._blocks(self, state)
_, state = self.blocks(self, state)
except Exception:
error_msg = f"Error in block: ({self._blocks.__class__.__name__}):\n"
error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
logger.error(error_msg)
raise
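
The recurring change in this file replaces the deepcopy-returning `blocks` property with a plain attribute, so callers now see the live blocks object instead of a fresh copy on every access. A minimal sketch of the behavioral difference, using hypothetical classes rather than diffusers code:

```python
from copy import deepcopy

class Before:
    def __init__(self, blocks):
        self._blocks = blocks

    @property
    def blocks(self):
        # Every access hands out an independent copy; mutations never
        # reach the pipeline's own blocks.
        return deepcopy(self._blocks)

class After:
    def __init__(self, blocks):
        self.blocks = blocks  # plain attribute: callers share the live object

b, a = Before([1, 2]), After([1, 2])
print(b.blocks is b.blocks)  # False: a new copy per access
print(a.blocks is a.blocks)  # True: the same object every time
```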

View File

@@ -14,7 +14,7 @@ from .scheduling_utils import SchedulerMixin
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -28,8 +28,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:
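
This `betas_for_alpha_bar` hunk repeats verbatim across the scheduler files below. For reference, a self-contained sketch of the `cosine` transform the docstring describes (the standard squared-cosine formula; written from scratch here, not copied from the diff):

```python
import math

import torch

def betas_for_alpha_bar_cosine(num_diffusion_timesteps: int, max_beta: float = 0.999) -> torch.Tensor:
    # alpha_bar(t) = cos((t + 0.008) / 1.008 * pi / 2) ** 2; each beta is
    # 1 - alpha_bar(t2) / alpha_bar(t1) over consecutive steps, capped at max_beta.
    def alpha_bar(t: float) -> float:
        return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

print(betas_for_alpha_bar_cosine(10))
```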

View File

@@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -51,7 +51,7 @@ class DDIMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:
@@ -100,13 +100,14 @@ def betas_for_alpha_bar(
return torch.tensor(betas, dtype=torch.float32)
def rescale_zero_terminal_snr(alphas_cumprod: torch.Tensor) -> torch.Tensor:
def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR Based on (Algorithm 1)[https://huggingface.co/papers/2305.08891]
Rescales betas to have zero terminal SNR Based on https://huggingface.co/papers/2305.08891 (Algorithm 1)
Args:
alphas_cumprod (`torch.Tensor`):
The alphas cumulative products that the scheduler is being initialized with.
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
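
A sketch of the rescaling this docstring describes, assuming Algorithm 1 of the linked paper applied directly to `alphas_cumprod` (the variant whose signature is shown above):

```python
import torch

def rescale_zero_terminal_snr_sketch(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    alphas_bar_sqrt = alphas_cumprod.sqrt()
    # Remember the original boundary values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    # Shift so the final timestep has exactly zero SNR, then rescale so the
    # first timestep keeps its original value.
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
    return alphas_bar_sqrt**2
```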
@@ -141,11 +142,11 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
beta_start (`float`, defaults to 0.00085):
beta_start (`float`, defaults to 0.0001):
The starting `beta` value of inference.
beta_end (`float`, defaults to 0.0120):
beta_end (`float`, defaults to 0.02):
The final `beta` value.
beta_schedule (`str`, defaults to `"scaled_linear"`):
beta_schedule (`str`, defaults to `"linear"`):
The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear`, `scaled_linear`, or `squaredcos_cap_v2`.
trained_betas (`np.ndarray`, *optional*):
@@ -178,8 +179,6 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
dark samples instead of limiting it to samples with medium brightness. Loosely related to
[`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
snr_shift_scale (`float`, defaults to 3.0):
Shift scale for SNR.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -191,15 +190,15 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
num_train_timesteps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.0120,
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2"] = "scaled_linear",
beta_schedule: str = "scaled_linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
prediction_type: str = "epsilon",
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
timestep_spacing: Literal["linspace", "leading", "trailing"] = "leading",
timestep_spacing: str = "leading",
rescale_betas_zero_snr: bool = False,
snr_shift_scale: float = 3.0,
):
@@ -209,15 +208,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float64,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float64) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -247,7 +238,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
beta_prod_t = 1 - alpha_prod_t
@@ -274,11 +265,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def set_timesteps(
self,
num_inference_steps: int,
device: Optional[Union[str, torch.device]] = None,
) -> None:
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -330,7 +317,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
sample: torch.Tensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator: Optional[torch.Generator] = None,
generator=None,
variance_noise: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[DDIMSchedulerOutput, Tuple]:
@@ -341,7 +328,7 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
@@ -500,5 +487,5 @@ class CogVideoXDDIMScheduler(SchedulerMixin, ConfigMixin):
velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return velocity
def __len__(self) -> int:
def __len__(self):
return self.config.num_train_timesteps
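
The `timestep_spacing` choices referenced in this scheduler ("linspace", "leading", "trailing") are assumed here to follow the usual diffusers conventions; a sketch of how the three modes differ for 4 inference steps over 1000 training steps (ignoring the `steps_offset` that `leading` would normally add):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 4
step = num_train_timesteps // num_inference_steps  # 250

linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].astype(np.int64)
leading = (np.arange(0, num_inference_steps) * step)[::-1].astype(np.int64)
trailing = np.round(np.arange(num_train_timesteps, 0, -step)).astype(np.int64) - 1

print(linspace)  # [999 666 333   0]
print(leading)   # [750 500 250   0]
print(trailing)  # [999 749 499 249]
```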

View File

@@ -49,7 +49,7 @@ class DDIMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -63,8 +63,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -51,7 +51,7 @@ class DDIMParallelSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -48,7 +48,7 @@ class DDPMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -62,8 +62,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:
@@ -192,12 +192,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
variance_type: Literal[
"fixed_small",
"fixed_small_log",
"fixed_large",
"fixed_large_log",
"learned",
"learned_range",
"fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
] = "fixed_small",
clip_sample: bool = True,
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
@@ -215,15 +210,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -350,14 +337,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
t: int,
predicted_variance: Optional[torch.Tensor] = None,
variance_type: Optional[
Literal[
"fixed_small",
"fixed_small_log",
"fixed_large",
"fixed_large_log",
"learned",
"learned_range",
]
Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
] = None,
) -> torch.Tensor:
"""
@@ -492,10 +472,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
prev_t = self.previous_timestep(t)
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
"learned",
"learned_range",
]:
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
@@ -544,10 +521,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
if t > 0:
device = model_output.device
variance_noise = randn_tensor(
model_output.shape,
generator=generator,
device=device,
dtype=model_output.dtype,
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
if self.variance_type == "fixed_small_log":
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
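
The `learned` / `learned_range` branch above relies on a channel-doubling convention: the model emits its mean prediction and variance prediction stacked along the channel dimension, and `step` splits them apart. A toy illustration of just that split (arbitrary shapes, not diffusers code):

```python
import torch

sample = torch.randn(1, 4, 8, 8)        # current noisy sample
model_output = torch.randn(1, 8, 8, 8)  # 2x the sample channels: [mean | variance]

mean_pred, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
print(mean_pred.shape, predicted_variance.shape)  # both torch.Size([1, 4, 8, 8])
```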

View File

@@ -50,7 +50,7 @@ class DDPMParallelSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -64,8 +64,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:
@@ -202,12 +202,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
beta_schedule: Literal["linear", "scaled_linear", "squaredcos_cap_v2", "sigmoid"] = "linear",
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
variance_type: Literal[
"fixed_small",
"fixed_small_log",
"fixed_large",
"fixed_large_log",
"learned",
"learned_range",
"fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"
] = "fixed_small",
clip_sample: bool = True,
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "epsilon",
@@ -225,15 +220,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = (
torch.linspace(
beta_start**0.5,
beta_end**0.5,
num_train_timesteps,
dtype=torch.float32,
)
** 2
)
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
elif beta_schedule == "squaredcos_cap_v2":
# Glide cosine schedule
self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -363,14 +350,7 @@ class DDPMParallelScheduler(SchedulerMixin, ConfigMixin):
t: int,
predicted_variance: Optional[torch.Tensor] = None,
variance_type: Optional[
Literal[
"fixed_small",
"fixed_small_log",
"fixed_large",
"fixed_large_log",
"learned",
"learned_range",
]
Literal["fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned", "learned_range"]
] = None,
) -> torch.Tensor:
"""

View File

@@ -34,7 +34,7 @@ if is_scipy_available():
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -48,8 +48,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -52,7 +52,7 @@ class DDIMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -66,8 +66,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -34,7 +34,7 @@ if is_scipy_available():
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -48,8 +48,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -34,7 +34,7 @@ if is_scipy_available():
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -48,8 +48,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -117,7 +117,7 @@ class BrownianTreeNoiseSampler:
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -131,8 +131,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -50,8 +50,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -51,7 +51,7 @@ class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -54,7 +54,7 @@ class EulerDiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -68,8 +68,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -51,7 +51,7 @@ class HeunDiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -52,7 +52,7 @@ class KDPM2AncestralDiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -66,8 +66,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -51,7 +51,7 @@ class KDPM2DiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -65,8 +65,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -53,7 +53,7 @@ class LCMSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -67,8 +67,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -49,7 +49,7 @@ class LMSDiscreteSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -63,8 +63,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -28,7 +28,7 @@ from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, Schedul
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -42,8 +42,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -47,7 +47,7 @@ class RePaintSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -61,8 +61,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -35,7 +35,7 @@ if is_scipy_available():
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -49,8 +49,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -52,7 +52,7 @@ class TCDSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -66,8 +66,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -48,7 +48,7 @@ class UnCLIPSchedulerOutput(BaseOutput):
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -62,8 +62,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -34,7 +34,7 @@ if is_scipy_available():
def betas_for_alpha_bar(
num_diffusion_timesteps: int,
max_beta: float = 0.999,
alpha_transform_type: Literal["cosine", "exp", "laplace"] = "cosine",
alpha_transform_type: Literal["cosine", "exp"] = "cosine",
) -> torch.Tensor:
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
@@ -48,8 +48,8 @@ def betas_for_alpha_bar(
The number of betas to produce.
max_beta (`float`, defaults to `0.999`):
The maximum beta to use; use values lower than 1 to avoid numerical instability.
alpha_transform_type (`str`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine`, `exp`, or `laplace`.
alpha_transform_type (`"cosine"` or `"exp"`, defaults to `"cosine"`):
The type of noise schedule for `alpha_bar`. Choose from `cosine` or `exp`.
Returns:
`torch.Tensor`:

View File

@@ -130,6 +130,7 @@ from .loading_utils import get_module_from_name, get_submodule_by_name, load_ima
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
apply_lora_scale,
check_peft_version,
delete_adapter_layers,
get_adapter_name,

View File

@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
"""
import collections
import functools
import importlib
from typing import Optional
@@ -275,6 +276,59 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_scale(adapter_name, get_module_weight(weight, module_name))
def apply_lora_scale(kwargs_name: str = "joint_attention_kwargs"):
    """
    Decorator to automatically handle LoRA layer scaling/unscaling in forward methods.
    This decorator pops the `scale` entry from the specified kwargs dict, scales the LoRA layers before the forward
    pass, and ensures they are unscaled afterwards, even if the forward pass raises an exception.
    Args:
        kwargs_name (`str`, defaults to `"joint_attention_kwargs"`):
            The name of the keyword argument that carries the LoRA scale. Common values include
            `"joint_attention_kwargs"`, `"attention_kwargs"`, and `"cross_attention_kwargs"`.
    """

    def decorator(forward_fn):
        @functools.wraps(forward_fn)
        def wrapper(self, *args, **kwargs):
            from . import USE_PEFT_BACKEND  # imported here to avoid a circular import with utils/__init__

            lora_scale = 1.0
            attention_kwargs = kwargs.get(kwargs_name)
            if attention_kwargs is not None:
                attention_kwargs = attention_kwargs.copy()
                kwargs[kwargs_name] = attention_kwargs
                scale_was_passed = "scale" in attention_kwargs
                lora_scale = attention_kwargs.pop("scale", 1.0)
                # The scale is popped either way, but has no effect without the PEFT backend.
                if not USE_PEFT_BACKEND and scale_was_passed:
                    logger.warning(
                        f"Passing `scale` via `{kwargs_name}` when not using the PEFT backend is ineffective."
                    )

            # Apply LoRA scaling if using the PEFT backend
            if USE_PEFT_BACKEND:
                scale_lora_layers(self, lora_scale)

            try:
                # Execute the forward pass
                return forward_fn(self, *args, **kwargs)
            finally:
                # Always unscale, even if the forward pass raises an exception
                if USE_PEFT_BACKEND:
                    unscale_lora_layers(self, lora_scale)

        return wrapper

    return decorator
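
A usage sketch of the new decorator on a toy module. The `TinyTransformer` class below is hypothetical, and the `diffusers.utils` export of `apply_lora_scale` exists only on this branch (per the `utils/__init__` hunk above):

```python
import torch
import torch.nn as nn

from diffusers.utils import apply_lora_scale  # exported by this PR

class TinyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    @apply_lora_scale("joint_attention_kwargs")
    def forward(self, hidden_states, joint_attention_kwargs=None):
        # By the time this body runs, the decorator has already popped "scale"
        # from joint_attention_kwargs and, with the PEFT backend, scaled any LoRA layers.
        return self.proj(hidden_states)

model = TinyTransformer()
x = torch.randn(2, 8)
# Callers keep passing the LoRA scale exactly as before the refactor:
y = model(x, joint_attention_kwargs={"scale": 0.7})
print(y.shape)  # torch.Size([2, 8])
```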
def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.