Compare commits

59 Commits

Author SHA1 Message Date
yiyixuxu
d34c4e8caf update the description of StableDiffusionXLDenoiseLoopWrapper 2025-06-20 07:38:21 +02:00
yiyixuxu
b46b7c8b31 add to() method to modular loader, copied from DiffusionPipeline, not tested yet 2025-06-20 07:25:20 +02:00
yiyixuxu
fc9168f429 add block mappings to modular_diffusers.stable_diffusion_xl.__init__ 2025-06-20 07:24:14 +02:00
yiyixuxu
31a31ca1c5 rename modular_pipeline_block_mappings.py to modular_block_mapping 2025-06-20 07:23:14 +02:00
yiyixuxu
8423652b35 update modular_pipeline.from_pretrained: modular_repo -> pretrained_model_name_or_path 2025-06-19 05:30:18 +02:00
yiyixuxu
de631947cc up 2025-06-19 04:45:20 +02:00
yiyixuxu
58e9565719 update doc format for kwargs_type 2025-06-19 02:24:51 +02:00
yiyixuxu
cb6d5fed19 refactor based on Dhruv's feedback 2025-06-18 10:11:22 +02:00
yiyixuxu
f16e9c7807 add 2025-06-10 23:10:17 +02:00
yiyixuxu
87f63d424a modular node! 2025-05-22 11:50:36 +02:00
yiyixuxu
29de29f02c add node_utils 2025-05-21 22:31:10 +02:00
yiyixuxu
72e1b74638 solve merge conflict: manually add back the remote code change to modular_pipeline 2025-05-20 20:26:51 +02:00
yiyixuxu
3471f2fb75 merge part1 2025-05-20 18:53:04 +02:00
yiyixuxu
d136ae36c8 update input for loop blocks, do not need to include intermediate 2025-05-20 18:11:05 +02:00
yiyixuxu
1b89ac144c prepare_latents_img2img pipeline method -> function, maybe do the same for others? 2025-05-20 18:10:06 +02:00
yiyixuxu
eb9415031a add a to-do for modular loader 2025-05-20 18:08:28 +02:00
yiyixuxu
de6ab6b49d fix import in block mapping 2025-05-20 18:07:58 +02:00
yiyixuxu
4968edc5dc remove the duplicated components_manager file I forgot to delete 2025-05-20 18:07:27 +02:00
Dhruv Nair
808dff09cb [WIP] Modular Diffusers support custom code/pipeline blocks (#11539)
* update

* update
2025-05-20 15:12:51 +05:30
yiyixuxu
61dac3bbe4 up 2025-05-19 22:39:32 +02:00
yiyixuxu
73ab5725c2 update components manager 2025-05-18 19:09:01 +02:00
yiyixuxu
163341d3dd refactor modular loader: 1. load only loads pretrained components (only when no specific names are given) 2. update accepts create spec 3. move the update _component_spec logic outside register_components to each method that creates/updates the component: __init__/update/load 2025-05-18 18:58:26 +02:00
yiyixuxu
d0fbf745e6 refactor component spec: replace create/create_from_pretrained/create_from_config to just create and load method 2025-05-18 18:52:12 +02:00
yiyixuxu
27c1158b23 add a to-do for guider config mixin 2025-05-18 18:50:03 +02:00
yiyixuxu
96ce6744fe after_denoise -> decoders 2025-05-15 00:45:45 +02:00
yiyixuxu
8ad14a52cb make generator an intermediate (it is mutable) 2025-05-13 23:25:56 +02:00
yiyixuxu
a7fb2d2a22 remove the output step 2025-05-13 22:15:54 +02:00
yiyixuxu
a0deefb606 fix more 2025-05-13 20:51:21 +02:00
yiyixuxu
e2491af650 fix import 2025-05-13 20:42:57 +02:00
yiyixuxu
506a8ea09c fix imports 2025-05-13 04:36:06 +02:00
yiyixuxu
58358c2d00 decode block, if skip decoding do not need to update latent 2025-05-13 01:57:47 +02:00
yiyixuxu
5cde77f915 make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables 2025-05-13 01:52:51 +02:00
yiyixuxu
522e827625 move block mappings to its own file 2025-05-12 01:17:45 +02:00
yiyixuxu
144eae4e0b add block state will also make sure modified intermediates_inputs are updated 2025-05-12 01:16:42 +02:00
yiyixuxu
796453cad1 add notes 2025-05-12 01:14:43 +02:00
yiyixuxu
153ae34ff6 update __init__ 2025-05-10 03:50:47 +02:00
yiyixuxu
0acb5e1460 made a modular_pipelines folder! 2025-05-10 03:50:31 +02:00
yiyixuxu
462429b687 remove modular-related changes from pipelines folder 2025-05-10 03:50:10 +02:00
yiyixuxu
cf01aaeb49 update imports on guiders 2025-05-10 03:49:30 +02:00
yiyixuxu
2017ae5624 fix auto denoise so all tests pass 2025-05-09 08:19:24 +02:00
yiyixuxu
2b361a2413 fix get_execution_blocks with LoopSequential 2025-05-09 08:17:10 +02:00
yiyixuxu
c677d528e4 change warning to debug 2025-05-09 08:16:24 +02:00
yiyixuxu
0f0618ff2b refactor the denoise step using LoopSequential! also add a new file for denoise step 2025-05-08 11:28:52 +02:00
yiyixuxu
d89631fc50 update input formatting, consider kwargs_type inputs with no name, e.g. *_controlnet_kwargs 2025-05-08 11:27:17 +02:00
yiyixuxu
16b6583fa8 allow input_fields as input & update message 2025-05-08 11:25:31 +02:00
yiyixuxu
f552773572 remove controlnet union denoise step, refactor & reuse controlnet denoise step to accept additional controlnet kwargs 2025-05-06 10:00:14 +02:00
yiyixuxu
dc4dbfe107 refactor pipeline/block states so that they can dynamically accept kwargs 2025-05-06 09:58:44 +02:00
yiyixuxu
43ac1ff7e7 refactor controlnet union 2025-05-04 22:17:25 +02:00
yiyixuxu
efd70b7838 separate controlnet step into input + denoise 2025-05-03 20:22:05 +02:00
yiyixuxu
7ca860c24b rename pipeline -> components, data -> block_state 2025-05-03 01:32:59 +02:00
yiyixuxu
7b86fcea31 remove lora step and ip-adapter step -> no longer needed 2025-05-02 11:31:25 +02:00
yiyixuxu
c8b5d56412 make loader optional 2025-05-02 00:46:31 +02:00
YiYi Xu
ce642e92da Merge branch 'modular-diffusers' into modular-refactor 2025-04-30 17:56:51 -10:00
YiYi Xu
6d5beefe29 [modular diffusers] introducing ModularLoader (#11462)
* cfg; slg; pag; sdxl without controlnet

---------

Co-authored-by: Aryan <aryan@huggingface.co>
2025-04-30 11:17:20 -10:00
Aryan
b863bdd6ca Modular Diffusers Guiders (#11311)
* cfg; slg; pag; sdxl without controlnet

* support sdxl controlnet

* support controlnet union

* update

* update

* cfg zero*

* use unwrap_module for torch compiled modules

* remove guider kwargs

* remove commented code

* remove old guider

* fix slg bug

* remove debug print

* autoguidance

* smoothed energy guidance

* add note about seg

* tangential cfg

* cfg plus plus

* support cfgpp in ddim

* apply review suggestions

* refactor

* rename enable/disable

* remove cfg++ for now

* rename do_classifier_free_guidance->prepare_unconditional_embeds

* remove unused
2025-04-26 03:42:42 +05:30
yiyixuxu
d143851309 move methods to blocks 2025-04-12 11:46:25 +02:00
yiyixuxu
9ad1470d48 up 2025-04-11 18:29:21 +02:00
yiyixuxu
bf99ab2f55 up 2025-04-09 20:36:45 +02:00
yiyixuxu
ee842839ef add componentspec and configspec 2025-04-09 01:40:02 +02:00
40 changed files with 11087 additions and 6497 deletions

View File

@@ -34,10 +34,12 @@ from .utils import (
_import_structure = {
"configuration_utils": ["ConfigMixin"],
"guiders": [],
"hooks": [],
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"modular_pipelines": [],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
@@ -130,12 +132,26 @@ except OptionalDependencyNotAvailable:
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
else:
_import_structure["guiders"].extend(
[
"AdaptiveProjectedGuidance",
"AutoGuidance",
"ClassifierFreeGuidance",
"ClassifierFreeZeroStarGuidance",
"SkipLayerGuidance",
"SmoothedEnergyGuidance",
"TangentialClassifierFreeGuidance",
]
)
_import_structure["hooks"].extend(
[
"FasterCacheConfig",
"HookRegistry",
"PyramidAttentionBroadcastConfig",
"LayerSkipConfig",
"SmoothedEnergyGuidanceConfig",
"apply_faster_cache",
"apply_layer_skip",
"apply_pyramid_attention_broadcast",
]
)
@@ -239,13 +255,21 @@ else:
"KarrasVePipeline",
"LDMPipeline",
"LDMSuperResolutionPipeline",
"ModularPipeline",
"PNDMPipeline",
"RePaintPipeline",
"ScoreSdeVePipeline",
"StableDiffusionMixin",
]
)
_import_structure["modular_pipelines"].extend(
[
"ModularLoader",
"ModularPipeline",
"ModularPipelineBlocks",
"ComponentSpec",
"ComponentsManager",
]
)
_import_structure["quantizers"] = ["DiffusersQuantizer"]
_import_structure["schedulers"].extend(
[
@@ -494,12 +518,10 @@ else:
"StableDiffusionXLImg2ImgPipeline",
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLModularPipeline",
"StableDiffusionXLPAGImg2ImgPipeline",
"StableDiffusionXLPAGInpaintPipeline",
"StableDiffusionXLPAGPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLAutoPipeline",
"StableUnCLIPImg2ImgPipeline",
"StableUnCLIPPipeline",
"StableVideoDiffusionPipeline",
@@ -526,6 +548,24 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_objects"] = [
name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_")
]
else:
_import_structure["modular_pipelines"].extend(
[
"StableDiffusionXLAutoPipeline",
"StableDiffusionXLModularLoader",
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
raise OptionalDependencyNotAvailable()
@@ -731,10 +771,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .guiders import (
AdaptiveProjectedGuidance,
AutoGuidance,
ClassifierFreeGuidance,
ClassifierFreeZeroStarGuidance,
SkipLayerGuidance,
SmoothedEnergyGuidance,
TangentialClassifierFreeGuidance,
)
from .hooks import (
FasterCacheConfig,
HookRegistry,
LayerSkipConfig,
PyramidAttentionBroadcastConfig,
SmoothedEnergyGuidanceConfig,
apply_layer_skip,
apply_faster_cache,
apply_pyramid_attention_broadcast,
)
@@ -837,12 +889,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KarrasVePipeline,
LDMPipeline,
LDMSuperResolutionPipeline,
ModularPipeline,
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
StableDiffusionMixin,
)
from .modular_pipelines import (
ModularLoader,
ModularPipeline,
ModularPipelineBlocks,
ComponentSpec,
ComponentsManager,
)
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
@@ -1070,12 +1128,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularPipeline,
StableDiffusionXLPAGImg2ImgPipeline,
StableDiffusionXLPAGInpaintPipeline,
StableDiffusionXLPAGPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
StableUnCLIPImg2ImgPipeline,
StableUnCLIPPipeline,
StableVideoDiffusionPipeline,
@@ -1100,7 +1156,16 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
)
try:
if not (is_torch_available() and is_transformers_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipelines import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()

View File

@@ -1,745 +0,0 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from .models.attention_processor import (
Attention,
AttentionProcessor,
PAGCFGIdentitySelfAttnProcessor2_0,
PAGIdentitySelfAttnProcessor2_0,
)
from .utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
class CFGGuider:
"""
This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
"""
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0 and not self._disable_guidance
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def guidance_scale(self):
return self._guidance_scale
@property
def batch_size(self):
return self._batch_size
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
# a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead
disable_guidance = guider_kwargs.get("disable_guidance", False)
guidance_scale = guider_kwargs.get("guidance_scale", None)
if guidance_scale is None:
raise ValueError("guidance_scale is not provided in guider_kwargs")
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is not provided in guider_kwargs")
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
self._disable_guidance = disable_guidance
def reset_guider(self, pipeline):
pass
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 2:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size :]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Classifier-Free Guidance (CFG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 2
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Prepare the input for CFG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
single tensor or a list of tensors. It must have the same length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_classifier_free_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_classifier_free_guidance:
return cond_input
else:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
raise ValueError(f"Unsupported input type: {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output
noise_pred_uncond, noise_pred_text = model_output.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
class PAGGuider:
"""
This class is used to guide the pipeline with PAG (Perturbed Attention Guidance).
"""
def __init__(
self,
pag_applied_layers: Union[str, List[str]],
pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0(),
),
):
r"""
Set the self-attention layers to apply PAG. Raises ValueError if the input is invalid.
Args:
pag_applied_layers (`str` or `List[str]`):
One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
PAG is to be applied. A few ways of expected usage are as follows:
- Single layers specified as - "blocks.{layer_index}"
- Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
- Multiple layers as a block name - "mid"
- Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
pag_attn_processors:
(`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
attention processor is for PAG with CFG disabled (unconditional only).
"""
if not isinstance(pag_applied_layers, list):
pag_applied_layers = [pag_applied_layers]
if pag_attn_processors is not None:
if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
raise ValueError("Expected a tuple of two attention processors")
for i in range(len(pag_applied_layers)):
if not isinstance(pag_applied_layers[i], str):
raise ValueError(
f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
)
self.pag_applied_layers = pag_applied_layers
self._pag_attn_processors = pag_attn_processors
def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance):
r"""
Set the attention processor for the PAG layers.
"""
pag_attn_processors = self._pag_attn_processors
pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]
def is_self_attn(module: nn.Module) -> bool:
r"""
Check if the module is self-attention module based on its name.
"""
return isinstance(module, Attention) and not module.is_cross_attention
def is_fake_integral_match(layer_id, name):
layer_id = layer_id.split(".")[-1]
name = name.split(".")[-1]
return layer_id.isnumeric() and name.isnumeric() and layer_id == name
for layer_id in pag_applied_layers:
# for each PAG layer input, we find corresponding self-attention layers in the unet model
target_modules = []
for name, module in model.named_modules():
# Identify the following simple cases:
# (1) Self Attention layer existing
# (2) Whether the module name matches pag layer id even partially
# (3) Make sure it's not a fake integral match if the layer_id ends with a number
# For example, blocks.1, blocks.10 should be distinguishable if layer_id="blocks.1"
if (
is_self_attn(module)
and re.search(layer_id, name) is not None
and not is_fake_integral_match(layer_id, name)
):
logger.debug(f"Applying PAG to layer: {name}")
target_modules.append(module)
if len(target_modules) == 0:
raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")
for module in target_modules:
module.processor = pag_attn_proc
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1 and not self._disable_guidance
@property
def do_perturbed_attention_guidance(self):
return self._pag_scale > 0 and not self._disable_guidance
@property
def do_pag_adaptive_scaling(self):
return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def batch_size(self):
return self._batch_size
@property
def pag_scale(self):
return self._pag_scale
@property
def pag_adaptive_scale(self):
return self._pag_adaptive_scale
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
pag_scale = guider_kwargs.get("pag_scale", 3.0)
pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is a required argument for PAGGuider")
guidance_scale = guider_kwargs.get("guidance_scale", None)
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
disable_guidance = guider_kwargs.get("disable_guidance", False)
if guidance_scale is None:
raise ValueError("guidance_scale is a required argument for PAGGuider")
self._pag_scale = pag_scale
self._pag_adaptive_scale = pag_adaptive_scale
self._guidance_scale = guidance_scale
self._disable_guidance = disable_guidance
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None:
pipeline.original_attn_proc = pipeline.unet.attn_processors
self._set_pag_attn_processor(
model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer,
pag_applied_layers=self.pag_applied_layers,
do_classifier_free_guidance=self.do_classifier_free_guidance,
)
def reset_guider(self, pipeline):
if (
self.do_perturbed_attention_guidance
and hasattr(pipeline, "original_attn_proc")
and pipeline.original_attn_proc is not None
):
pipeline.unet.set_attn_processor(pipeline.original_attn_proc)
pipeline.original_attn_proc = None
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Perturbed Attention Guidance (PAG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 3
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Perturbed Attention Guidance (PAG).
This method handles inputs that may already have PAG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 3:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size : self.batch_size * 2]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
"""
Prepare the input for PAG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_perturbed_attention_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
cond = torch.cat([cond] * 2, dim=0)
if self.do_classifier_free_guidance:
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
else:
prepared_input.append(cond)
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_perturbed_attention_guidance:
return cond_input
cond_input = torch.cat([cond_input] * 2, dim=0)
if self.do_classifier_free_guidance:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
return cond_input
else:
raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_perturbed_attention_guidance:
return model_output
if self.do_pag_adaptive_scaling:
pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0)
else:
pag_scale = self._pag_scale
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3)
noise_pred = (
noise_pred_uncond
+ self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+ pag_scale * (noise_pred_text - noise_pred_perturb)
)
else:
noise_pred_text, noise_pred_perturb = model_output.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
class APGGuider:
"""
This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
"""
def normalized_guidance(
self,
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: MomentumBuffer = None,
norm_threshold: float = 0.0,
eta: float = 1.0,
):
"""
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
Models](https://arxiv.org/pdf/2410.02416)
"""
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided
@property
def adaptive_projected_guidance_momentum(self):
return self._adaptive_projected_guidance_momentum
@property
def adaptive_projected_guidance_rescale_factor(self):
return self._adaptive_projected_guidance_rescale_factor
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0 and not self._disable_guidance
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def guidance_scale(self):
return self._guidance_scale
@property
def batch_size(self):
return self._batch_size
def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
disable_guidance = guider_kwargs.get("disable_guidance", False)
guidance_scale = guider_kwargs.get("guidance_scale", None)
if guidance_scale is None:
raise ValueError("guidance_scale is not provided in guider_kwargs")
adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
"adaptive_projected_guidance_rescale_factor", 15.0
)
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is not provided in guider_kwargs")
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
self._disable_guidance = disable_guidance
if adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
else:
self.momentum_buffer = None
self.scheduler = pipeline.scheduler
def reset_guider(self, pipeline):
pass
def maybe_update_guider(self, pipeline, timestep):
pass
def maybe_update_input(self, pipeline, cond_input):
pass
def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.
Args:
cond (torch.Tensor): The conditional input tensor to process.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 2:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size :]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")
def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Classifier-Free Guidance (CFG).
Args:
cond (torch.Tensor): The conditional input tensor to check.
Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
return cond_tensor.shape[0] == self.batch_size * 2
def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Prepare the input for CFG.
Args:
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
single tensor or a list of tensors. It must have the same length as `cond_input`.
Returns:
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
"""
# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
return cond_input
if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)
if isinstance(cond_input, (list, tuple)):
if not self.do_classifier_free_guidance:
return cond_input
if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
return prepared_input
elif isinstance(cond_input, torch.Tensor):
if not self.do_classifier_free_guidance:
return cond_input
else:
return torch.cat([negative_cond_input, cond_input], dim=0)
else:
raise ValueError(f"Unsupported input type: {type(cond_input)}")
def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output
if latents is None:
raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
sigma = self.scheduler.sigmas[self.scheduler.step_index]
noise_pred = latents - sigma * model_output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = self.normalized_guidance(
noise_pred_text,
noise_pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.adaptive_projected_guidance_rescale_factor,
)
noise_pred = (latents - noise_pred) / sigma
if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
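The deleted file above is the legacy, pipeline-coupled guider module. For reference, a minimal sketch of how its protocol was driven (set_guider -> prepare_input -> apply_guidance); the random tensors are stand-ins for real embeddings and denoiser outputs:

import torch

guider = CFGGuider()
guider.set_guider(pipeline=None, guider_kwargs={"guidance_scale": 7.5, "batch_size": 1})

cond = torch.randn(1, 4, 8, 8)    # conditional embedding stand-in
uncond = torch.randn(1, 4, 8, 8)  # unconditional counterpart
stacked = guider.prepare_input(cond, uncond)  # cat([uncond, cond], dim=0) -> shape (2, 4, 8, 8)

model_output = stacked * 0.1                      # stand-in for a denoiser forward pass
noise_pred = guider.apply_guidance(model_output)  # uncond + 7.5 * (text - uncond)
print(noise_pred.shape)                           # torch.Size([1, 4, 8, 8])

The guiders added below replace this protocol with stateful classes that prepare per-condition batches (prepare_inputs) and combine predictions in forward.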

View File

@@ -0,0 +1,29 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from ..utils import is_torch_available
if is_torch_available():
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
from .auto_guidance import AutoGuidance
from .classifier_free_guidance import ClassifierFreeGuidance
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
GuiderType = Union[AdaptiveProjectedGuidance, AutoGuidance, ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance, SkipLayerGuidance, SmoothedEnergyGuidance, TangentialClassifierFreeGuidance]
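This new guiders/__init__.py gates its imports on torch availability and exposes GuiderType, a union over all seven guider classes. A hypothetical annotation use (the describe helper is an illustration, not part of the PR):

from diffusers.guiders import ClassifierFreeGuidance, GuiderType

# Any of the seven guider classes type-checks against the alias; they all
# expose a guidance_scale attribute.
def describe(guider: GuiderType) -> str:
    return f"{type(guider).__name__}(scale={guider.guidance_scale})"

print(describe(ClassifierFreeGuidance(guidance_scale=5.0)))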

View File

@@ -0,0 +1,184 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AdaptiveProjectedGuidance(BaseGuidance):
"""
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
The norm threshold applied to the guidance update. If the norm of the update exceeds this value, it is
rescaled down to the threshold, which helps prevent oversaturation (passed as `norm_threshold` to
`normalized_guidance`).
eta (`float`, defaults to `1.0`):
The weight applied to the parallel component of the projected guidance update; `1.0` recovers the full
update direction.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum: Optional[float] = None,
adaptive_projected_guidance_rescale: float = 15.0,
eta: float = 1.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_apg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_apg_enabled():
num_conditions += 1
return num_conditions
def _is_apg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average
def normalized_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: Optional[MomentumBuffer] = None,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update
return pred
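The normalized_guidance function above is the core of APG: the update pred_cond - pred_uncond is decomposed into components parallel and orthogonal to the (normalized) conditional prediction, and only the parallel component, which drives oversaturation, is weighted by eta. A standalone sanity check of that decomposition, assuming nothing beyond torch:

import torch

pred_cond = torch.randn(2, 4, 8, 8)
pred_uncond = torch.randn(2, 4, 8, 8)
diff = pred_cond - pred_uncond

dim = [-1, -2, -3]  # all non-batch dimensions
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
parallel = (diff.double() * v1).sum(dim=dim, keepdim=True) * v1
orthogonal = diff.double() - parallel

# Orthogonal by construction: the inner product is ~0 up to float error.
print((parallel * orthogonal).sum(dim=dim).abs().max())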

View File

@@ -0,0 +1,177 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class AutoGuidance(BaseGuidance):
"""
AutoGuidance: https://huggingface.co/papers/2406.02507
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
auto_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply layer skipping to for the unconditional branch. Can be a single integer or a
list of integers. If not provided, `auto_guidance_config` must be provided.
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skipped layers. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `auto_guidance_layers` must be provided.
dropout (`float`, *optional*):
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
auto_guidance_config: Optional[Union[LayerSkipConfig, List[LayerSkipConfig]]] = None,
dropout: Optional[float] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.auto_guidance_layers = auto_guidance_layers
self.auto_guidance_config = auto_guidance_config
self.dropout = dropout
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if auto_guidance_layers is None and auto_guidance_config is None:
raise ValueError(
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
)
if auto_guidance_layers is not None and auto_guidance_config is not None:
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
if (dropout is None and auto_guidance_layers is not None) or (dropout is not None and auto_guidance_layers is None):
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
if auto_guidance_layers is not None:
if isinstance(auto_guidance_layers, int):
auto_guidance_layers = [auto_guidance_layers]
if not isinstance(auto_guidance_layers, list):
raise ValueError(
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
)
auto_guidance_config = [LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers]
if isinstance(auto_guidance_config, LayerSkipConfig):
auto_guidance_config = [auto_guidance_config]
if not isinstance(auto_guidance_config, list):
raise ValueError(
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
)
self.auto_guidance_config = auto_guidance_config
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_ag_enabled() and self.is_unconditional:
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_ag_enabled() and self.is_unconditional:
for name in self._auto_guidance_hook_names:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_ag_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_ag_enabled():
num_conditions += 1
return num_conditions
def _is_ag_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
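AutoGuidance's constructor accepts either raw layer indices plus a dropout, or pre-built LayerSkipConfig objects; the first form is normalized into the second internally. A sketch of the two equivalent call paths (the layer indices and the 0.1 dropout are arbitrary example values):

from diffusers.guiders import AutoGuidance
from diffusers.hooks import LayerSkipConfig

# Path 1: layer indices + dropout (dropout is required alongside auto_guidance_layers).
guider = AutoGuidance(guidance_scale=7.5, auto_guidance_layers=[7, 8, 9], dropout=0.1)

# Path 2: explicit configs -- what Path 1 is normalized into.
configs = [LayerSkipConfig(i, fqn="auto", dropout=0.1) for i in (7, 8, 9)]
guider = AutoGuidance(guidance_scale=7.5, auto_guidance_config=configs)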

View File

@@ -0,0 +1,132 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeGuidance(BaseGuidance):
"""
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
inference. This allows the model to trade off between generation quality and sample diversity.
The original paper proposes scaling and shifting the conditional distribution based on the difference between
conditional and unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
The intuition behind the original formulation can be thought of as moving the conditional distribution estimates
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
the unconditional predictions (usually negative features like "bad quality, bad anatomy, watermarks", etc.)
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self, guidance_scale: float = 7.5, guidance_rescale: float = 0.0, use_original_formulation: bool = False, start: float = 0.0, stop: float = 1.0
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled():
pred = pred_cond
else:
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
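The docstring above notes that the original and diffusers-native CFG formulations are equivalent in theory; concretely, both trace the same line through the two predictions, with the scales offset by one. A quick numeric check:

import torch

pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = torch.randn(1, 4, 8, 8)
s = 6.5

original = pred_cond + s * (pred_cond - pred_uncond)                     # paper formulation, scale s
diffusers_native = pred_uncond + (s + 1.0) * (pred_cond - pred_uncond)  # diffusers formulation, scale s + 1
print(torch.allclose(original, diffusers_native, atol=1e-5))            # True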

View File

@@ -0,0 +1,148 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class ClassifierFreeZeroStarGuidance(BaseGuidance):
"""
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
quality of generated images.
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
zero_init_steps (`int`, defaults to `1`):
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
zero_init_steps: int = 1,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.zero_init_steps = zero_init_steps
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if self._step < self.zero_init_steps:
pred = torch.zeros_like(pred_cond)
elif not self._is_cfg_enabled():
pred = pred_cond
else:
pred_cond_flat = pred_cond.flatten(1)
pred_uncond_flat = pred_uncond.flatten(1)
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
pred_uncond = pred_uncond * alpha
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
cond_dtype = cond.dtype
cond = cond.float()
uncond = uncond.float()
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
scale = dot_product / squared_norm
return scale.to(dtype=cond_dtype)
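The returned scale is the per-sample least-squares projection coefficient of the conditional prediction onto the unconditional one. A quick check of this property, assuming the `(batch, dim)` shapes produced by `flatten(1)` in `forward`:

```python
import torch

cond = torch.randn(2, 16)
uncond = torch.randn(2, 16)
alpha = cfg_zero_star_scale(cond, uncond)  # shape (2, 1)

# alpha * uncond is the projection of cond onto uncond, so the residual
# is orthogonal to uncond (up to the eps stabilizer in the denominator)
residual = cond - alpha * uncond
dots = (residual * uncond).sum(dim=1)
assert torch.allclose(dots, torch.zeros(2), atol=1e-4)
```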


@@ -0,0 +1,215 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
import torch
from ..utils import get_logger
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
logger = get_logger(__name__) # pylint: disable=invalid-name
class BaseGuidance:
r"""Base class providing the skeleton for implementing guidance techniques."""
_input_predictions = None
_identifier_key = "__guidance_identifier__"
def __init__(self, start: float = 0.0, stop: float = 1.0):
self._start = start
self._stop = stop
self._step: int = None
self._num_inference_steps: int = None
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(
f"Expected `start` to be between 0.0 and 1.0, but got {start}."
)
if not (start <= stop <= 1.0):
raise ValueError(
f"Expected `stop` to be between {start} and 1.0, but got {stop}."
)
if self._input_predictions is None or not isinstance(self._input_predictions, list):
raise ValueError(
"`_input_predictions` must be a list of required prediction names for the guidance technique."
)
def disable(self):
self._enabled = False
def enable(self):
self._enabled = True
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
self._step = step
self._num_inference_steps = num_inference_steps
self._timestep = timestep
self._count_prepared = 0
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
"""
Set the input fields for the guidance technique. The input fields are used to specify the names of the
returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is
obtained from the values of the provided keyword arguments to this method.
Args:
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
which is used to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
be the conditional data identifier and the second element must be the unconditional data identifier
or None.
Example:
```
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
BaseGuidance.set_input_fields(
latents="latents",
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
```
"""
for key, value in kwargs.items():
is_string = isinstance(value, str)
is_tuple_of_str_with_len_2 = isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
if not (is_string or is_tuple_of_str_with_len_2):
raise ValueError(
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
)
self._input_fields = kwargs
def prepare_models(self, denoiser: torch.nn.Module) -> None:
"""
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
subclasses to implement specific model preparation logic.
"""
self._count_prepared += 1
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
"""
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden in
subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
modifications made during `prepare_models`.
"""
pass
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def __call__(self, data: List["BlockState"]) -> Any:
if not all(hasattr(d, "noise_pred") for d in data):
raise ValueError("Expected all data to have `noise_pred` attribute.")
if len(data) != self.num_conditions:
raise ValueError(
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
)
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
return self.forward(**forward_inputs)
def forward(self, *args, **kwargs) -> Any:
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
@property
def is_conditional(self) -> bool:
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
@property
def is_unconditional(self) -> bool:
return not self.is_conditional
@property
def num_conditions(self) -> int:
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
@classmethod
def _prepare_batch(cls, input_fields: Dict[str, Union[str, Tuple[str, str]]], data: "BlockState", tuple_index: int, identifier: str) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of
the `BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once
it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2,
which is used to look up the required data provided for preparation.
If a string is provided, it will be used as the conditional data (or unconditional if used with
a guidance method that requires it). If a tuple of length 2 is provided, the first element must
be the conditional data identifier and the second element must be the unconditional data identifier
or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
r"""
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Args:
noise_cfg (`torch.Tensor`):
The predicted noise tensor for the guided diffusion process.
noise_pred_text (`torch.Tensor`):
The predicted noise tensor for the text-guided diffusion process.
guidance_rescale (`float`, *optional*, defaults to 0.0):
A rescale factor applied to the noise predictions.
Returns:
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
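To make the lifecycle above concrete, here is a hedged sketch of how a denoising loop might drive a guider built on `BaseGuidance`; `guider`, `denoiser`, `data`, and `timesteps` are assumptions, not names from this file:

```python
# configure which BlockState fields hold (conditional, unconditional) data
guider.set_input_fields(
    latents="latents",
    prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
)
for i, t in enumerate(timesteps):
    guider.set_state(step=i, num_inference_steps=len(timesteps), timestep=t)
    batches = guider.prepare_inputs(data)  # one BlockState per condition
    for batch in batches:
        guider.prepare_models(denoiser)  # e.g. attach layer-skip hooks
        batch.noise_pred = denoiser(batch.latents, t, batch.prompt_embeds)
        guider.cleanup_models(denoiser)
    noise_pred, _ = guider(batches)  # combine predictions via forward()
```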


@@ -0,0 +1,251 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry, LayerSkipConfig
from ..hooks.layer_skip import _apply_layer_skip_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SkipLayerGuidance(BaseGuidance):
"""
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
SLG was introduced by StabilityAI for improving structure and anatomy coherence in generated images. It works by
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
batch of data, apart from the conditional and unconditional batches already used in CFG
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
based on the difference between conditional without skipping and conditional with skipping predictions.
The intuition behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
version of the model for the conditional prediction).
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
generation quality in video diffusion models.
Additional reading:
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
skip_layer_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
skip_layer_guidance_start (`float`, defaults to `0.01`):
The fraction of the total number of denoising steps after which skip layer guidance starts.
skip_layer_guidance_stop (`float`, defaults to `0.2`):
The fraction of the total number of denoising steps after which skip layer guidance stops.
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
def __init__(
self,
guidance_scale: float = 7.5,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_start: float = 0.01,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.skip_layer_guidance_scale = skip_layer_guidance_scale
self.skip_layer_guidance_start = skip_layer_guidance_start
self.skip_layer_guidance_stop = skip_layer_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= skip_layer_guidance_start < 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
)
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
raise ValueError(
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
)
if skip_layer_guidance_layers is None and skip_layer_config is None:
raise ValueError(
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
)
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
if skip_layer_guidance_layers is not None:
if isinstance(skip_layer_guidance_layers, int):
skip_layer_guidance_layers = [skip_layer_guidance_layers]
if not isinstance(skip_layer_guidance_layers, list):
raise ValueError(
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
)
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
if isinstance(skip_layer_config, LayerSkipConfig):
skip_layer_config = [skip_layer_config]
if not isinstance(skip_layer_config, list):
raise ValueError(
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
)
self.skip_layer_config = skip_layer_config
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
self._count_prepared += 1
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
_apply_layer_skip_hook(denoiser, config, name=name)
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
# pred_cond_skip is computed from the conditional inputs, so it reuses tuple index 0 when CFG is disabled
tuple_indices = [0, 1] if self._is_cfg_enabled() else [0, 0]
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_skip: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_slg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_cond_skip
pred = pred + self.skip_layer_guidance_scale * shift
elif not self._is_slg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_slg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_slg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
return is_within_range and not is_zero
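A minimal construction sketch (the `diffusers.guiders` import path is an assumption):

```python
from diffusers.guiders import SkipLayerGuidance

guider = SkipLayerGuidance(
    guidance_scale=7.0,
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_layers=[7, 8, 9],  # StabilityAI's SD 3.5 Medium recommendation
)
# with CFG and SLG both active, three denoiser passes are expected per step:
# pred_cond, pred_uncond, and pred_cond_skip (computed with layers 7-9 skipped)
assert guider.num_conditions == 3
```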


@@ -0,0 +1,244 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
from ..hooks import HookRegistry
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class SmoothedEnergyGuidance(BaseGuidance):
"""
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility. This implementation assumes:
- Generated images are square (height == width)
- The model does not combine different modalities together (e.g., it does not fuse text and image latent
streams into a single sequence, as Flux does)
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
seg_guidance_scale (`float`, defaults to `2.8`):
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
values, but it may also lead to overexposure and saturation.
seg_blur_sigma (`float`, defaults to `9999999.0`):
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
The threshold above which the blur is considered infinite.
seg_guidance_start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
seg_guidance_stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
seg_guidance_layers (`int` or `List[int]`, *optional*):
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If not
provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
3.5 Medium.
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or a list of
`SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
def __init__(
self,
guidance_scale: float = 7.5,
seg_guidance_scale: float = 2.8,
seg_blur_sigma: float = 9999999.0,
seg_blur_threshold_inf: float = 9999.0,
seg_guidance_start: float = 0.0,
seg_guidance_stop: float = 1.0,
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.seg_guidance_scale = seg_guidance_scale
self.seg_blur_sigma = seg_blur_sigma
self.seg_blur_threshold_inf = seg_blur_threshold_inf
self.seg_guidance_start = seg_guidance_start
self.seg_guidance_stop = seg_guidance_stop
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
if not (0.0 <= seg_guidance_start < 1.0):
raise ValueError(
f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}."
)
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
raise ValueError(
f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}."
)
if seg_guidance_layers is None and seg_guidance_config is None:
raise ValueError(
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
)
if seg_guidance_layers is not None and seg_guidance_config is not None:
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
if seg_guidance_layers is not None:
if isinstance(seg_guidance_layers, int):
seg_guidance_layers = [seg_guidance_layers]
if not isinstance(seg_guidance_layers, list):
raise ValueError(
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
)
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
seg_guidance_config = [seg_guidance_config]
if not isinstance(seg_guidance_config, list):
raise ValueError(
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
)
self.seg_guidance_config = seg_guidance_config
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
def prepare_models(self, denoiser: torch.nn.Module) -> None:
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
def cleanup_models(self, denoiser: torch.nn.Module):
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
# Remove the hooks after inference
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
# pred_cond_seg is computed from the conditional inputs, so it reuses tuple index 0 when CFG is disabled
tuple_indices = [0, 1] if self._is_cfg_enabled() else [0, 0]
input_predictions = ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(
self,
pred_cond: torch.Tensor,
pred_uncond: Optional[torch.Tensor] = None,
pred_cond_seg: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pred = None
if not self._is_cfg_enabled() and not self._is_seg_enabled():
pred = pred_cond
elif not self._is_cfg_enabled():
shift = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_cond_seg
pred = pred + self.seg_guidance_scale * shift
elif not self._is_seg_enabled():
shift = pred_cond - pred_uncond
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift
else:
shift = pred_cond - pred_uncond
shift_seg = pred_cond - pred_cond_seg
pred = pred_cond if self.use_original_formulation else pred_uncond
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1 or self._count_prepared == 3
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_cfg_enabled():
num_conditions += 1
if self._is_seg_enabled():
num_conditions += 1
return num_conditions
def _is_cfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def _is_seg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
is_within_range = skip_start_step < self._step < skip_stop_step
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
return is_within_range and not is_zero
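A minimal construction sketch, mirroring the one for `SkipLayerGuidance` (the import path is an assumption):

```python
from diffusers.guiders import SmoothedEnergyGuidance

guider = SmoothedEnergyGuidance(
    guidance_scale=7.5,
    seg_guidance_scale=2.8,
    seg_blur_sigma=9999999.0,  # above seg_blur_threshold_inf: infinite blur, i.e. uniform queries
    seg_guidance_layers=[7, 8, 9],
)
assert guider.num_conditions == 3  # pred_cond, pred_uncond, pred_cond_seg
```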


@@ -0,0 +1,137 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
from .guider_utils import BaseGuidance, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class TangentialClassifierFreeGuidance(BaseGuidance):
"""
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
Args:
guidance_scale (`float`, defaults to `7.5`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
def __init__(
self,
guidance_scale: float = 7.5,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__(start, stop)
self.guidance_scale = guidance_scale
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
pred = None
if not self._is_tcfg_enabled():
pred = pred_cond
else:
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return pred, {}
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_tcfg_enabled():
num_conditions += 1
return num_conditions
def _is_tcfg_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def normalized_guidance(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False) -> torch.Tensor:
cond_dtype = pred_cond.dtype
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
preds = preds.flatten(2)
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
Vh_modified = Vh.clone()
Vh_modified[:, 1] = 0
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
pred = pred_cond if use_original_formulation else pred_uncond
shift = pred_cond - pred_uncond
pred = pred + guidance_scale * shift
return pred
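One checkable property of `normalized_guidance`: when the unconditional prediction is parallel to the conditional one, the second singular direction carries no energy, the projection leaves `pred_uncond` (numerically) unchanged, and TCFG reduces to plain CFG. A sketch under that assumption:

```python
import torch

pred_cond = torch.randn(1, 4, 8, 8)
pred_uncond = 0.5 * pred_cond  # parallel case
out = normalized_guidance(pred_cond, pred_uncond, guidance_scale=7.5)
cfg = pred_uncond + 7.5 * (pred_cond - pred_uncond)
assert torch.allclose(out, cfg, atol=1e-4)
```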


@@ -5,5 +5,7 @@ if is_torch_available():
from .faster_cache import FasterCacheConfig, apply_faster_cache
from .group_offloading import apply_group_offloading
from .hooks import HookRegistry, ModelHook
from .layer_skip import LayerSkipConfig, apply_layer_skip
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig


@@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
from ..models.attention import FeedForward, LuminaFeedForward
from ..models.attention_processor import Attention, MochiAttention
_ATTENTION_CLASSES = (Attention, MochiAttention)
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
{
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
}
)
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
for submodule_name, submodule in module.named_modules():
if submodule_name == fqn:
return submodule
return None
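A toy lookup to illustrate the fqn semantics (the module below is an assumption):

```python
import torch

model = torch.nn.Module()
model.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])
assert isinstance(_get_submodule_from_fqn(model, "transformer_blocks"), torch.nn.ModuleList)
assert isinstance(_get_submodule_from_fqn(model, "transformer_blocks.0"), torch.nn.Linear)
assert _get_submodule_from_fqn(model, "does_not_exist") is None
```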


@@ -0,0 +1,271 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Type
from ..models.attention import BasicTransformerBlock
from ..models.attention_processor import AttnProcessor2_0
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor, CogView4TransformerBlock
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from ..models.transformers.transformer_hunyuan_video import (
HunyuanVideoSingleTransformerBlock,
HunyuanVideoTokenReplaceSingleTransformerBlock,
HunyuanVideoTokenReplaceTransformerBlock,
HunyuanVideoTransformerBlock,
)
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
from ..models.transformers.transformer_mochi import MochiTransformerBlock
from ..models.transformers.transformer_wan import WanTransformerBlock
@dataclass
class AttentionProcessorMetadata:
skip_processor_output_fn: Callable[[Any], Any]
@dataclass
class TransformerBlockMetadata:
skip_block_output_fn: Callable[[Any], Any]
return_hidden_states_index: int = None
return_encoder_hidden_states_index: int = None
class AttentionProcessorRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
class TransformerBlockRegistry:
_registry = {}
@classmethod
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
cls._registry[model_class] = metadata
@classmethod
def get(cls, model_class: Type) -> TransformerBlockMetadata:
if model_class not in cls._registry:
raise ValueError(f"Model class {model_class} not registered.")
return cls._registry[model_class]
def _register_attention_processors_metadata():
# AttnProcessor2_0
AttentionProcessorRegistry.register(
model_class=AttnProcessor2_0,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
),
)
# CogView4AttnProcessor
AttentionProcessorRegistry.register(
model_class=CogView4AttnProcessor,
metadata=AttentionProcessorMetadata(
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
),
)
def _register_transformer_blocks_metadata():
# BasicTransformerBlock
TransformerBlockRegistry.register(
model_class=BasicTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_BasicTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# CogVideoX
TransformerBlockRegistry.register(
model_class=CogVideoXBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogVideoXBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# CogView4
TransformerBlockRegistry.register(
model_class=CogView4TransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_CogView4TransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Flux
TransformerBlockRegistry.register(
model_class=FluxTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
TransformerBlockRegistry.register(
model_class=FluxSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_FluxSingleTransformerBlock,
return_hidden_states_index=1,
return_encoder_hidden_states_index=0,
),
)
# HunyuanVideo
TransformerBlockRegistry.register(
model_class=HunyuanVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
TransformerBlockRegistry.register(
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# LTXVideo
TransformerBlockRegistry.register(
model_class=LTXVideoTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_LTXVideoTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# Mochi
TransformerBlockRegistry.register(
model_class=MochiTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_MochiTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=1,
),
)
# Wan
TransformerBlockRegistry.register(
model_class=WanTransformerBlock,
metadata=TransformerBlockMetadata(
skip_block_output_fn=_skip_block_output_fn_WanTransformerBlock,
return_hidden_states_index=0,
return_encoder_hidden_states_index=None,
),
)
# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___ret___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
return hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return hidden_states, encoder_hidden_states
def _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states(self, *args, **kwargs):
hidden_states = kwargs.get("hidden_states", None)
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if hidden_states is None and len(args) > 0:
hidden_states = args[0]
if encoder_hidden_states is None and len(args) > 1:
encoder_hidden_states = args[1]
return encoder_hidden_states, hidden_states
_skip_block_output_fn_BasicTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_CogVideoXBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_CogView4TransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_FluxTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_FluxSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___encoder_hidden_states___hidden_states
_skip_block_output_fn_HunyuanVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_HunyuanVideoTokenReplaceSingleTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_LTXVideoTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
_skip_block_output_fn_MochiTransformerBlock = _skip_block_output_fn___hidden_states_0___encoder_hidden_states_1___ret___hidden_states___encoder_hidden_states
_skip_block_output_fn_WanTransformerBlock = _skip_block_output_fn___hidden_states_0___ret___hidden_states
# fmt: on
_register_attention_processors_metadata()
_register_transformer_blocks_metadata()
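After the registrations above, hook code looks metadata up by the exact block class; a hedged sketch (`diffusers.models.attention` mirrors the relative import used in this file):

```python
from diffusers.models.attention import BasicTransformerBlock

metadata = TransformerBlockRegistry.get(BasicTransformerBlock)
# BasicTransformerBlock returns only hidden_states, at tuple index 0
assert metadata.return_hidden_states_index == 0
assert metadata.return_encoder_hidden_states_index is None
```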


@@ -0,0 +1,231 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Callable, List, Optional
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _FEEDFORWARD_CLASSES, _get_submodule_from_fqn
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_LAYER_SKIP_HOOK = "layer_skip_hook"
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
# either remove or make it serializable
@dataclass
class LayerSkipConfig:
r"""
Configuration for skipping internal transformer blocks when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the transformer blocks to skip within the block stack identified by `fqn`.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`.
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
skip_attention (`bool`, defaults to `True`):
Whether to skip attention blocks.
skip_ff (`bool`, defaults to `True`):
Whether to skip feed-forward blocks.
skip_attention_scores (`bool`, defaults to `False`):
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
projections as the output of scaled dot product attention.
dropout (`float`, defaults to `1.0`):
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
skipped layers are fully retained, which is equivalent to not skipping any layers.
"""
indices: List[int]
fqn: str = "auto"
skip_attention: bool = True
skip_attention_scores: bool = False
skip_ff: bool = True
dropout: float = 1.0
def __post_init__(self):
if not (0 <= self.dropout <= 1):
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.nn.functional.scaled_dot_product_attention:
value = kwargs.get("value", None)
if value is None:
value = args[2]
return value
return func(*args, **kwargs)
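A minimal demonstration of the function mode: inside the context, `scaled_dot_product_attention` returns its `value` argument untouched, which is exactly the "skip attention scores" behavior described in `LayerSkipConfig`:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
with AttentionScoreSkipFunctionMode():
    out = F.scaled_dot_product_attention(q, k, v)
assert torch.equal(out, v)  # scores skipped, value passed through
```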
class AttentionProcessorSkipHook(ModelHook):
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
self.skip_processor_output_fn = skip_processor_output_fn
self.skip_attention_scores = skip_attention_scores
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if self.skip_attention_scores:
if not math.isclose(self.dropout, 1.0):
raise ValueError(
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
)
with AttentionScoreSkipFunctionMode():
output = self.fn_ref.original_forward(*args, **kwargs)
else:
if math.isclose(self.dropout, 1.0):
output = self.skip_processor_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class FeedForwardSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = kwargs.get("hidden_states", None)
if output is None:
output = kwargs.get("x", None)
if output is None and len(args) > 0:
output = args[0]
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
class TransformerBlockSkipHook(ModelHook):
def __init__(self, dropout: float):
super().__init__()
self.dropout = dropout
def initialize_hook(self, module):
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
if math.isclose(self.dropout, 1.0):
output = self._metadata.skip_block_output_fn(module, *args, **kwargs)
else:
output = self.fn_ref.original_forward(*args, **kwargs)
output = torch.nn.functional.dropout(output, p=self.dropout)
return output
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
r"""
Apply layer skipping to internal layers of a transformer.
Args:
module (`torch.nn.Module`):
The transformer model to which the layer skip hook should be applied.
config (`LayerSkipConfig`):
The configuration for the layer skip hook.
Example:
```python
>>> import torch
>>> from diffusers import apply_layer_skip, CogVideoXTransformer3DModel, LayerSkipConfig

>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
>>> config = LayerSkipConfig(indices=[10, 20], fqn="transformer_blocks")
>>> apply_layer_skip(transformer, config)
```
"""
_apply_layer_skip_hook(module, config)
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
name = name or _LAYER_SKIP_HOOK
if config.skip_attention and config.skip_attention_scores:
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
raise ValueError("Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0.")
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
raise ValueError(
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
)
if len(config.indices) == 0:
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
if config.skip_attention and config.skip_ff:
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = TransformerBlockSkipHook(config.dropout)
registry.register_hook(hook, name)
elif config.skip_attention or config.skip_attention_scores:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
registry.register_hook(hook, name)
if config.skip_ff:
for submodule_name, submodule in block.named_modules():
if isinstance(submodule, _FEEDFORWARD_CLASSES):
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
registry = HookRegistry.check_if_exists_or_initialize(submodule)
hook = FeedForwardSkipHook(config.dropout)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
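For reference, a hedged sketch of how the config fields combine under `_apply_layer_skip_hook` (configs only; no model required):

```python
# skip whole blocks 10 and 20: attention + feed-forward (TransformerBlockSkipHook path)
full_skip = LayerSkipConfig(indices=[10, 20])

# keep the blocks, but replace attention output with the value projection
# (requires skip_attention=False and dropout=1.0)
scores_only = LayerSkipConfig(indices=[10, 20], skip_attention=False, skip_attention_scores=True, skip_ff=False)

# stochastically drop feed-forward outputs with probability 0.7
partial_ff = LayerSkipConfig(indices=[10, 20], skip_attention=False, skip_ff=True, dropout=0.7)
```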

View File

@@ -0,0 +1,158 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from ..utils import get_logger
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
from .hooks import HookRegistry, ModelHook
logger = get_logger(__name__) # pylint: disable=invalid-name
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
@dataclass
class SmoothedEnergyGuidanceConfig:
r"""
Configuration for Smoothed Energy Guidance: selects the transformer blocks whose attention query projections are blurred when executing a transformer model.
Args:
indices (`List[int]`):
The indices of the transformer blocks whose self-attention query projections should be blurred.
fqn (`str`, defaults to `"auto"`):
The fully qualified name identifying the stack of transformer blocks. Typically, this is
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
For automatic detection, set this to `"auto"`.
"auto" only works on DiT models. For UNet models, you must provide the correct fqn.
_query_proj_identifiers (`List[str]`, defaults to `None`):
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`.
If `None`, `to_q` is used by default.
"""
indices: List[int]
fqn: str = "auto"
_query_proj_identifiers: Optional[List[str]] = None
class SmoothedEnergyGuidanceHook(ModelHook):
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
super().__init__()
self.blur_sigma = blur_sigma
self.blur_threshold_inf = blur_threshold_inf
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
return smoothed_output
def _apply_smoothed_energy_guidance_hook(module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None) -> None:
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
if config.fqn == "auto":
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
if hasattr(module, identifier):
config.fqn = identifier
break
else:
raise ValueError(
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
)
if config._query_proj_identifiers is None:
config._query_proj_identifiers = ["to_q"]
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
blocks_found = False
for i, block in enumerate(transformer_blocks):
if i not in config.indices:
continue
blocks_found = True
for submodule_name, submodule in block.named_modules():
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
continue
for identifier in config._query_proj_identifiers:
query_proj = getattr(submodule, identifier, None)
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
continue
logger.debug(
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
)
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
hook = SmoothedEnergyGuidanceHook(blur_sigma)
registry.register_hook(hook, name)
if not blocks_found:
raise ValueError(
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
)
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
"""
This implementation assumes that the input query holds visual (image/video) tokens, so that a 2D gaussian
blur can be applied. However, some models use joint text-visual token attention, for which this may not be
suitable. It also assumes that the visual tokens come from a square image/video. In practice, despite these
assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
Smoothed Energy Guidance.
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified
in the future without warning or guarantee of reproducibility.
"""
assert query.ndim == 3
is_inf = sigma > sigma_threshold_inf
batch_size, seq_len, embed_dim = query.shape
seq_len_sqrt = int(math.sqrt(seq_len))
num_square_tokens = seq_len_sqrt * seq_len_sqrt
query_slice = query[:, :num_square_tokens, :]
query_slice = query_slice.permute(0, 2, 1)
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
    if is_inf:
        # a sigma above the threshold is treated as infinite: the gaussian degenerates
        # to a per-channel global mean over the square token grid
        query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
    else:
        kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
        kernel_size_half = (kernel_size - 1) / 2
        x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
        pdf = torch.exp(-0.5 * (x / sigma).pow(2))
        kernel1d = pdf / pdf.sum()
        kernel1d = kernel1d.to(query)
        kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
        kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
        padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
        query_slice = F.pad(query_slice, padding, mode="reflect")
        query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
query_slice = query_slice.permute(0, 2, 1)
query[:, :num_square_tokens, :] = query_slice.clone()
return query
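A small smoke test (not part of the diff) under the stated assumptions: a 3D query over a square token grid. Note the function also writes its result back into `query` in place:

```python
import math
import torch

blur_sigma = 1.0
# same kernel-size formula used in SmoothedEnergyGuidanceHook.post_forward
kernel_size = math.ceil(6 * blur_sigma) + 1 - math.ceil(6 * blur_sigma) % 2

query = torch.randn(2, 64, 32)  # batch=2, 8x8 visual tokens, embed_dim=32
blurred = _gaussian_blur_2d(query, kernel_size, blur_sigma, sigma_threshold_inf=9999.9)
assert blurred.shape == (2, 64, 32)

# a sigma above the threshold is treated as infinite and collapses to a per-channel mean
averaged = _gaussian_blur_2d(torch.randn(2, 64, 32), kernel_size, 1e6, 9999.9)
```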

View File

@@ -0,0 +1,84 @@
from typing import TYPE_CHECKING
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
_import_structure = {}
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import dummy_pt_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
_import_structure["modular_pipeline"] = [
"ModularPipelineBlocks",
"ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks",
"SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks",
"ModularLoader",
"PipelineState",
"BlockState",
]
_import_structure["modular_pipeline_utils"] = [
"ComponentSpec",
"ConfigSpec",
"InputParam",
"OutputParam",
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoPipeline", "StableDiffusionXLModularLoader"]
_import_structure["components_manager"] = ["ComponentsManager"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
LoopSequentialPipelineBlocks,
ModularLoader,
ModularPipelineBlocks,
ModularPipeline,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
from .modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
from .stable_diffusion_xl import (
StableDiffusionXLAutoPipeline,
StableDiffusionXLModularLoader,
)
from .components_manager import ComponentsManager
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
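Assuming this `__init__.py` lives under `diffusers/modular_diffusers/` (a hypothetical path inferred from the relative imports; it is not shown in this diff), the `_LazyModule` indirection defers the torch-heavy imports until first attribute access:

```python
# hypothetical import path; only components_manager is actually imported here
from diffusers.modular_diffusers import ComponentsManager

manager = ComponentsManager()
```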

View File

@@ -26,6 +26,10 @@ from ..utils import (
logging,
)
from ..models.modeling_utils import ModelMixin
from .modular_pipeline_utils import ComponentSpec
import uuid
if is_accelerate_available():
@@ -229,54 +233,209 @@ class AutoOffloadStrategy:
return hooks_to_offload
class ComponentsManager:
def __init__(self):
self.components = OrderedDict()
self.added_time = OrderedDict()  # Store when components were added
self.collections = OrderedDict() # collection_name -> set of component_names
self.model_hooks = None
self._auto_offload_enabled = False
def add(self, name, component):
if name in self.components:
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
self.components[name] = component
self.added_time[name] = time.time()
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
def remove(self, name):
if name not in self.components:
logger.warning(f"Component '{name}' not found in ComponentsManager")
return
self.components.pop(name)
self.added_time.pop(name)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
# YiYi TODO: looking into improving the search pattern
def get(self, names: Union[str, List[str]]):
def _lookup_ids(self, name=None, collection=None, load_id=None, components: OrderedDict = None):
"""
Get components by name with simple pattern matching.
Lookup component_ids by name, collection, or load_id.
"""
if components is None:
components = self.components
if name:
ids_by_name = set()
for component_id, component in components.items():
comp_name = self._id_to_name(component_id)
if comp_name == name:
ids_by_name.add(component_id)
else:
ids_by_name = set(components.keys())
if collection:
ids_by_collection = set()
for component_id, component in components.items():
if component_id in self.collections[collection]:
ids_by_collection.add(component_id)
else:
ids_by_collection = set(components.keys())
if load_id:
ids_by_load_id = set()
for name, component in components.items():
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id:
ids_by_load_id.add(name)
else:
ids_by_load_id = set(components.keys())
ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id)
return ids
@staticmethod
def _id_to_name(component_id: str):
return "_".join(component_id.split("_")[:-1])
def add(self, name, component, collection: Optional[str] = None):
component_id = f"{name}_{uuid.uuid4()}"
# check for duplicated components
for comp_id, comp in self.components.items():
if comp == component:
comp_name = self._id_to_name(comp_id)
if comp_name == name:
logger.warning(
f"component '{name}' already exists as '{comp_id}'"
)
component_id = comp_id
break
else:
logger.warning(
f"Adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# check for duplicated load_id and warn (we do not delete for you)
if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]
if components_with_same_load_id:
existing = ", ".join(components_with_same_load_id)
logger.warning(
f"Adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
)
# add component to components manager
self.components[component_id] = component
self.added_time[component_id] = time.time()
if collection:
if collection not in self.collections:
self.collections[collection] = set()
if component_id not in self.collections[collection]:
comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
for comp_id in comp_ids_in_collection:
logger.info(f"Removing existing {name} from collection '{collection}': {comp_id}")
self.remove(comp_id)
self.collections[collection].add(component_id)
logger.info(f"Added component '{name}' in collection '{collection}': {component_id}")
else:
logger.info(f"Added component '{name}' as '{component_id}'")
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
return component_id
def remove(self, component_id: str = None):
if component_id not in self.components:
logger.warning(f"Component '{component_id}' not found in ComponentsManager")
return
component = self.components.pop(component_id)
self.added_time.pop(component_id)
for collection in self.collections:
if component_id in self.collections[collection]:
self.collections[collection].remove(component_id)
if self._auto_offload_enabled:
self.enable_auto_cpu_offload(self._auto_offload_device)
else:
if isinstance(component, torch.nn.Module):
component.to("cpu")
del component
import gc
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get(self, names: Union[str, List[str]] = None, collection: Optional[str] = None, load_id: Optional[str] = None,
as_name_component_tuples: bool = False):
"""
Select components by name with simple pattern matching.
Args:
names: Component name(s) or pattern(s)
Patterns:
- "unet" : exact match
- "!unet" : everything except exact match "unet"
- "base_*" : everything starting with "base_"
- "!base_*" : everything NOT starting with "base_"
- "*unet*" : anything containing "unet"
- "!*unet*" : anything NOT containing "unet"
- "refiner|vae|unet" : anything containing any of these terms
- "!refiner|vae|unet" : anything NOT containing any of these terms
- "unet" : match any component with base name "unet" (e.g., unet_123abc)
- "!unet" : everything except components with base name "unet"
- "unet*" : anything with base name starting with "unet"
- "!unet*" : anything with base name NOT starting with "unet"
- "*unet*" : anything with base name containing "unet"
- "!*unet*" : anything with base name NOT containing "unet"
- "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
- "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
- "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
collection: Optional collection to filter by
load_id: Optional load_id to filter by
as_name_component_tuples: If True, returns a list of (name, component) tuples using base names
instead of a dictionary with component IDs as keys
Returns:
Single component if names is str and matches one component,
dict of components if names matches multiple components or is a list
Dictionary mapping component IDs to components,
or list of (base_name, component) tuples if as_name_component_tuples=True
"""
selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
components = {k: self.components[k] for k in selected_ids}
# Helper to extract base name from component_id
def get_base_name(component_id):
parts = component_id.split('_')
# If the last part looks like a UUID, remove it
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return component_id
if names is None:
if as_name_component_tuples:
return [(get_base_name(comp_id), comp) for comp_id, comp in components.items()]
else:
return components
# Create mapping from component_id to base_name for all components
base_names = {comp_id: get_base_name(comp_id) for comp_id in components.keys()}
def matches_pattern(component_id, pattern, exact_match=False):
"""
Helper function to check if a component matches a pattern based on its base name.
Args:
component_id: The component ID to check
pattern: The pattern to match against
exact_match: If True, only exact matches to base_name are considered
"""
base_name = base_names[component_id]
# Exact match with base name
if exact_match:
return pattern == base_name
# Prefix match (ends with *)
elif pattern.endswith('*'):
prefix = pattern[:-1]
return base_name.startswith(prefix)
# Contains match (starts with *)
elif pattern.startswith('*'):
search = pattern[1:-1] if pattern.endswith('*') else pattern[1:]
return search in base_name
# Exact match (no wildcards)
else:
return pattern == base_name
if isinstance(names, str):
# Check if this is a "not" pattern
is_not_pattern = names.startswith('!')
@@ -286,33 +445,45 @@ class ComponentsManager:
# Handle OR patterns (containing |)
if '|' in names:
terms = names.split('|')
matches = {
name: comp for name, comp in self.components.items()
if any((term in name) != is_not_pattern for term in terms) # Flip condition if not pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
else:
logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")
matches = {}
for comp_id, comp in components.items():
# For OR patterns with exact names (no wildcards), we do exact matching on base names
exact_match = all(not (term.startswith('*') or term.endswith('*')) for term in terms)
# Check if any of the terms match this component
should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)
# Flip the decision if this is a NOT pattern
if is_not_pattern:
should_include = not should_include
if should_include:
matches[comp_id] = comp
log_msg = "NOT " if is_not_pattern else ""
match_type = "exactly matching" if exact_match else "matching any of patterns"
logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")
# Exact match
elif names in self.components:
# Try exact match with a base name
elif any(names == base_name for base_name in base_names.values()):
# Find all components with this base name
matches = {
comp_id: comp for comp_id, comp in components.items()
if (base_names[comp_id] == names) != is_not_pattern
}
if is_not_pattern:
matches = {
name: comp for name, comp in self.components.items()
if name != names
}
logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting component: {names}")
return self.components[names]
logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")
# Prefix match (ends with *)
elif names.endswith('*'):
prefix = names[:-1]
matches = {
name: comp for name, comp in self.components.items()
if name.startswith(prefix) != is_not_pattern # Flip condition if not pattern
comp_id: comp for comp_id, comp in components.items()
if base_names[comp_id].startswith(prefix) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
@@ -323,31 +494,46 @@ class ComponentsManager:
elif names.startswith('*'):
search = names[1:-1] if names.endswith('*') else names[1:]
matches = {
name: comp for name, comp in self.components.items()
if (search in name) != is_not_pattern # Flip condition if not pattern
comp_id: comp for comp_id, comp in components.items()
if (search in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{search}': {list(matches.keys())}")
# Substring match (no wildcards, but not an exact component name)
elif any(names in base_name for base_name in base_names.values()):
matches = {
comp_id: comp for comp_id, comp in components.items()
if (names in base_names[comp_id]) != is_not_pattern
}
if is_not_pattern:
logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
else:
logger.info(f"Getting components containing '{names}': {list(matches.keys())}")
else:
raise ValueError(f"Component '{names}' not found in ComponentsManager")
raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")
if not matches:
raise ValueError(f"No components found matching pattern '{names}'")
return matches if len(matches) > 1 else next(iter(matches.values()))
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in matches.items()]
else:
return matches
elif isinstance(names, list):
results = {}
for name in names:
result = self.get(name)
if isinstance(result, dict):
results.update(result)
else:
results[name] = result
logger.info(f"Getting multiple components: {list(results.keys())}")
return results
result = self.get(name, collection, load_id, as_name_component_tuples=False)
results.update(result)
if as_name_component_tuples:
return [(base_names[comp_id], comp) for comp_id, comp in results.items()]
else:
return results
else:
raise ValueError(f"Invalid type for names: {type(names)}")
@@ -391,11 +577,12 @@ class ComponentsManager:
self.model_hooks = None
self._auto_offload_enabled = False
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
# YiYi TODO: add quantization info
def get_model_info(self, component_id: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
"""Get comprehensive information about a component.
Args:
name: Name of the component to get info for
component_id: Name of the component to get info for
fields: Optional field(s) to return. Can be a string for single field or list of fields.
If None, returns all fields.
@@ -404,23 +591,32 @@ class ComponentsManager:
If fields is specified, returns only those fields.
If a single field is requested as string, returns just that field's value.
"""
if name not in self.components:
raise ValueError(f"Component '{name}' not found in ComponentsManager")
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
component = self.components[name]
component = self.components[component_id]
# Build complete info dict first
info = {
"model_id": name,
"added_time": self.added_time[name],
"model_id": component_id,
"added_time": self.added_time[component_id],
"collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps]) or None,
}
# Additional info for torch.nn.Module components
if isinstance(component, torch.nn.Module):
# Check for hook information
has_hook = hasattr(component, "_hf_hook")
execution_device = None
if has_hook and hasattr(component._hf_hook, "execution_device"):
execution_device = component._hf_hook.execution_device
info.update({
"class_name": component.__class__.__name__,
"size_gb": get_memory_footprint(component) / (1024**3),
"adapters": None, # Default to None
"has_hook": has_hook,
"execution_device": execution_device,
})
# Get adapters if applicable
@@ -454,12 +650,64 @@ class ComponentsManager:
return info
def __repr__(self):
# Helper to get simple name without UUID
def get_simple_name(name):
# Extract the base name by splitting on underscore and taking first part
# This assumes names are in format "name_uuid"
parts = name.split('_')
# If we have at least 2 parts and the last part looks like a UUID, remove it
if len(parts) > 1 and len(parts[-1]) >= 8 and '-' in parts[-1]:
return '_'.join(parts[:-1])
return name
# Extract load_id if available
def get_load_id(component):
if hasattr(component, "_diffusers_load_id"):
return component._diffusers_load_id
return "N/A"
# Format device info compactly
def format_device(component, info):
if not info["has_hook"]:
return str(getattr(component, 'device', 'N/A'))
else:
device = str(getattr(component, 'device', 'N/A'))
exec_device = str(info['execution_device'] or 'N/A')
return f"{device}({exec_device})"
# Get all simple names to calculate width
simple_names = [get_simple_name(id) for id in self.components.keys()]
# Get max length of load_ids for models
load_ids = [
get_load_id(component)
for component in self.components.values()
if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
]
max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15
# Get all collections for each component
component_collections = {}
for name in self.components.keys():
component_collections[name] = []
for coll, comps in self.collections.items():
if name in comps:
component_collections[name].append(coll)
if not component_collections[name]:
component_collections[name] = ["N/A"]
# Find the maximum collection name length
all_collections = [coll for colls in component_collections.values() for coll in colls]
max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10
col_widths = {
"id": max(15, max(len(id) for id in self.components.keys())),
"name": max(15, max(len(name) for name in simple_names)),
"class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
"device": 10,
"device": 15, # Reduced since using more compact format
"dtype": 15,
"size": 10,
"load_id": max_load_id_len,
"collection": max_collection_len
}
# Create the header lines
@@ -476,17 +724,33 @@ class ComponentsManager:
if models:
output += "Models:\n" + dash_line
# Column headers
output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | "
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
output += dash_line
# Model entries
for name, component in models.items():
info = self.get_model_info(name)
device = str(getattr(component, "device", "N/A"))
simple_name = get_simple_name(name)
device_str = format_device(component, info)
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
load_id = get_load_id(component)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {info['class_name']:<{col_widths['class']}} | "
output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | "
output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"
output += dash_line
# Other components section
@@ -495,12 +759,24 @@ class ComponentsManager:
output += "\n"
output += "Other Components:\n" + dash_line
# Column headers for other components
output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
output += f"{'Name':<{col_widths['name']}} | {'Class':<{col_widths['class']}} | Collection\n"
output += dash_line
# Other component entries
for name, component in others.items():
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
info = self.get_model_info(name)
simple_name = get_simple_name(name)
# Print first collection on the main line
first_collection = component_collections[name][0] if component_collections[name] else "N/A"
output += f"{simple_name:<{col_widths['name']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"
# Print additional collections on separate lines if they exist
for i in range(1, len(component_collections[name])):
collection = component_collections[name][i]
output += f"{'':<{col_widths['name']}} | {'':<{col_widths['class']}} | {collection}\n"
output += dash_line
# Add additional component info
@@ -508,7 +784,8 @@ class ComponentsManager:
for name in self.components:
info = self.get_model_info(name)
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
output += f"\n{name}:\n"
simple_name = get_simple_name(name)
output += f"\n{simple_name}:\n"
if info.get("adapters") is not None:
output += f" Adapters: {info['adapters']}\n"
if info.get("ip_adapter"):
@@ -517,7 +794,7 @@ class ComponentsManager:
return output
def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
def from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
"""
Load components from a pretrained model and add them to the manager.
@@ -527,17 +804,12 @@ class ComponentsManager:
If provided, components will be named as "{prefix}_{component_name}"
**kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
"""
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
subfolder = kwargs.pop("subfolder", None)
# YiYi TODO: extend AutoModel to support non-diffusers models
if subfolder:
from ..models import AutoModel
component = AutoModel.from_pretrained(pretrained_model_name_or_path, subfolder=subfolder, **kwargs)
component_name = f"{prefix}_{subfolder}" if prefix else subfolder
if component_name not in self.components:
self.add(component_name, component)
else:
@@ -546,6 +818,59 @@ class ComponentsManager:
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
else:
from ..pipelines.pipeline_utils import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
for name, component in pipe.components.items():
if component is None:
continue
# Add prefix if specified
component_name = f"{prefix}_{name}" if prefix else name
if component_name not in self.components:
self.add(component_name, component)
else:
logger.warning(
f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
f"1. remove the existing component with remove('{component_name}')\n"
f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
)
def get_one(self, component_id: Optional[str] = None, name: Optional[str] = None, collection: Optional[str] = None, load_id: Optional[str] = None) -> Any:
"""
Get a single component, either directly by `component_id` or by name/collection/load_id filters. Raises an error if multiple components match or none are found.
Args:
component_id: Optional component id to look up directly; mutually exclusive with the filters below
name: Component name or pattern
collection: Optional collection to filter by
load_id: Optional load_id to filter by
Returns:
A single component
Raises:
ValueError: If no components match or multiple components match
"""
# if component_id is provided, return the component
if component_id is not None and (name is not None or collection is not None or load_id is not None):
raise ValueError(" if component_id is provided, name, collection, and load_id must be None")
elif component_id is not None:
if component_id not in self.components:
raise ValueError(f"Component '{component_id}' not found in ComponentsManager")
return self.components[component_id]
results = self.get(name, collection, load_id)
if not results:
raise ValueError(f"No components found matching '{name}'")
if len(results) > 1:
raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}")
return next(iter(results.values()))
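A hedged usage sketch (not part of the diff) of the collection-aware API, with a small stand-in module:

```python
import torch

manager = ComponentsManager()
unet = torch.nn.Linear(4, 4)  # stand-in for a real UNet

unet_id = manager.add("unet", unet, collection="sdxl")  # returns "unet_<uuid>"
assert manager.get_one(component_id=unet_id) is unet
assert manager.get_one(name="unet", collection="sdxl") is unet

# adding another "unet" to the same collection evicts the previous one
manager.add("unet", torch.nn.Linear(4, 4), collection="sdxl")
```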
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
"""Summarizes a dictionary by finding common prefixes that share the same value.

File diff suppressed because it is too large

View File

@@ -0,0 +1,616 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import inspect
from dataclasses import dataclass, asdict, field, fields
from typing import Any, Dict, List, Optional, Tuple, Type, Union, Literal
from ..utils.import_utils import is_torch_available
from ..configuration_utils import FrozenDict, ConfigMixin
from collections import OrderedDict
if is_torch_available():
import torch
class InsertableOrderedDict(OrderedDict):
def insert(self, key, value, index):
items = list(self.items())
# Remove key if it already exists to avoid duplicates
items = [(k, v) for k, v in items if k != key]
# Insert at the specified index
items.insert(index, (key, value))
# Clear and update self
self.clear()
self.update(items)
# Return self for method chaining
return self
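For example (illustration only):

```python
d = InsertableOrderedDict(a=1, c=3)
d.insert("b", 2, 1)  # place "b" between "a" and "c"
assert list(d.items()) == [("a", 1), ("b", 2), ("c", 3)]
```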
# YiYi TODO:
# 1. validate the dataclass fields
# 2. add a validator for create_* methods, make sure they are valid inputs to pass to from_pretrained()
@dataclass
class ComponentSpec:
"""Specification for a pipeline component.
A component can be created in two ways:
1. From scratch using `__init__` with a config dict
2. From a pretrained checkpoint using `from_pretrained`
Attributes:
name: Name of the component
type_hint: Type of the component (e.g. UNet2DConditionModel)
description: Optional description of the component
config: Optional config dict for __init__ creation
repo: Optional repo path for from_pretrained creation
subfolder: Optional subfolder in repo
variant: Optional variant in repo
revision: Optional revision in repo
default_creation_method: Preferred creation method - "from_config" or "from_pretrained"
"""
name: Optional[str] = None
type_hint: Optional[Type] = None
description: Optional[str] = None
config: Optional[FrozenDict[str, Any]] = None
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
variant: Optional[str] = field(default=None, metadata={"loading": True})
revision: Optional[str] = field(default=None, metadata={"loading": True})
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"
def __hash__(self):
"""Make ComponentSpec hashable, using load_id as the hash value."""
return hash((self.name, self.load_id, self.default_creation_method))
def __eq__(self, other):
"""Compare ComponentSpec objects based on name and load_id."""
if not isinstance(other, ComponentSpec):
return False
return (self.name == other.name and
self.load_id == other.load_id and
self.default_creation_method == other.default_creation_method)
@classmethod
def from_component(cls, name: str, component: Any) -> Any:
"""Create a ComponentSpec from a Component created by `create` or `load` method."""
if not hasattr(component, "_diffusers_load_id"):
raise ValueError("Component is not created by `create` or `load` method")
# throw an error if the component was created with the `create` method but is not a subclass of ConfigMixin
# YiYi TODO: remove this check if we remove support for non configmixin in `create()` method
if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin):
raise ValueError(
"We currently only support creating ComponentSpec from a component with "
"created with `ComponentSpec.load` method"
"or created with `ComponentSpec.create` and a subclass of ConfigMixin"
)
type_hint = component.__class__
default_creation_method = "from_config" if component._diffusers_load_id == "null" else "from_pretrained"
if isinstance(component, ConfigMixin):
config = component.config
else:
config = None
load_spec = cls.decode_load_id(component._diffusers_load_id)
return cls(name=name, type_hint=type_hint, config=config, default_creation_method=default_creation_method, **load_spec)
@classmethod
def loading_fields(cls) -> List[str]:
"""
Return the names of all loading-related fields
(i.e. those whose field.metadata["loading"] is True).
"""
return [f.name for f in fields(cls) if f.metadata.get("loading", False)]
@property
def load_id(self) -> str:
"""
Unique identifier for this spec's pretrained load,
composed of repo|subfolder|variant|revision (no empty segments).
"""
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
@classmethod
def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:
"""
Decode a load_id string back into a dictionary of loading fields and values.
Args:
load_id: The load_id string to decode, format: "repo|subfolder|variant|revision"
where None values are represented as "null"
Returns:
Dict mapping loading field names to their values. e.g.
{
"repo": "path/to/repo",
"subfolder": "subfolder",
"variant": "variant",
"revision": "revision"
}
If a segment value is "null", it's replaced with None.
Returns None if load_id is "null" (indicating component not created with `load` method).
"""
# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
if load_id == "null":
return result
# Split the load_id
parts = load_id.split("|")
# Map parts to loading fields by position
for i, part in enumerate(parts):
if i < len(loading_fields):
# Convert "null" string back to None
result[loading_fields[i]] = None if part == "null" else part
return result
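A round-trip sketch (the repo name is illustrative):

```python
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet")
assert spec.load_id == "stabilityai/stable-diffusion-xl-base-1.0|unet|null|null"
assert ComponentSpec.decode_load_id(spec.load_id) == {
    "repo": "stabilityai/stable-diffusion-xl-base-1.0",
    "subfolder": "unet",
    "variant": None,
    "revision": None,
}
```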
# YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin)
# otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component)
# the config info is lost in the process
# remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method
def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any:
"""Create component using from_config with config."""
if self.type_hint is None or not isinstance(self.type_hint, type):
raise ValueError(
f"`type_hint` is required when using from_config creation method."
)
config = config or self.config or {}
if issubclass(self.type_hint, ConfigMixin):
component = self.type_hint.from_config(config, **kwargs)
else:
signature_params = inspect.signature(self.type_hint.__init__).parameters
init_kwargs = {}
for k, v in config.items():
if k in signature_params:
init_kwargs[k] = v
for k, v in kwargs.items():
if k in signature_params:
init_kwargs[k] = v
component = self.type_hint(**init_kwargs)
component._diffusers_load_id = "null"
if hasattr(component, "config"):
self.config = component.config
return component
# YiYi TODO: add guard for type of model, if it is supported by from_pretrained
def load(self, **kwargs) -> Any:
"""Load component using from_pretrained."""
# select loading fields from kwargs passed from user: e.g. repo, subfolder, variant, revision, note the list could change
passed_loading_kwargs = {key: kwargs.pop(key) for key in self.loading_fields() if key in kwargs}
# merge loading field value in the spec with user passed values to create load_kwargs
load_kwargs = {key: passed_loading_kwargs.get(key, getattr(self, key)) for key in self.loading_fields()}
# repo is a required argument for from_pretrained, a.k.a. pretrained_model_name_or_path
repo = load_kwargs.pop("repo", None)
if repo is None:
raise ValueError(f"`repo` info is required when using `load` method (you can directly set it in `repo` field of the ComponentSpec or pass it as an argument)")
if self.type_hint is None:
try:
from diffusers import AutoModel
component = AutoModel.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
# update type_hint if AutoModel load successfully
self.type_hint = component.__class__
else:
try:
component = self.type_hint.from_pretrained(repo, **load_kwargs, **kwargs)
except Exception as e:
raise ValueError(f"Unable to load {self.name} using load method: {e}")
self.repo = repo
for k, v in load_kwargs.items():
setattr(self, k, v)
component._diffusers_load_id = self.load_id
return component
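And a sketch of the pretrained path (assumes network access and the usual SDXL repo layout):

```python
import torch
from diffusers import UNet2DConditionModel

spec = ComponentSpec(
    name="unet",
    type_hint=UNet2DConditionModel,
    repo="stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
)
unet = spec.load(torch_dtype=torch.float16)  # forwards non-loading kwargs to from_pretrained
assert unet._diffusers_load_id == spec.load_id
```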
@dataclass
class ConfigSpec:
"""Specification for a pipeline configuration parameter."""
name: str
default: Any
description: Optional[str] = None
# YiYi Notes: both inputs and intermediates_inputs are InputParam objects,
# but some fields are not relevant for intermediates_inputs:
# e.g. unlike inputs, `required` is only used in the docstring for intermediates_inputs; we do not check whether a required intermediate input is passed
# `default` is not used for intermediates_inputs; we only use the default from inputs, so it is ignored if set on an intermediates_input
# -> should we use a different class for inputs and intermediates_inputs?
@dataclass
class InputParam:
"""Specification for an input parameter."""
name: str = None
type_hint: Any = None
default: Any = None
required: bool = False
description: str = ""
kwargs_type: str = None # YiYi Notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
def __repr__(self):
return f"<{self.name}: {self.type_hint.__name__ if hasattr(self.type_hint, '__name__') else str(self.type_hint)}>"
def format_inputs_short(inputs):
"""
Format input parameters into a string representation, with required params first followed by optional ones.
Args:
inputs: List of input parameters with 'required' and 'name' attributes, and 'default' for optional params
Returns:
str: Formatted string of input parameters
Example:
>>> inputs = [
... InputParam(name="prompt", required=True),
... InputParam(name="image", required=True),
... InputParam(name="guidance_scale", required=False, default=7.5),
... InputParam(name="num_inference_steps", required=False, default=50)
... ]
>>> format_inputs_short(inputs)
'prompt, image, guidance_scale=7.5, num_inference_steps=50'
"""
required_inputs = [param for param in inputs if param.required]
optional_inputs = [param for param in inputs if not param.required]
required_str = ", ".join(param.name for param in required_inputs)
optional_str = ", ".join(f"{param.name}={param.default}" for param in optional_inputs)
inputs_str = required_str
if optional_str:
inputs_str = f"{inputs_str}, {optional_str}" if required_str else optional_str
return inputs_str
def format_intermediates_short(intermediates_inputs, required_intermediates_inputs, intermediates_outputs):
"""
Formats intermediate inputs and outputs of a block into a string representation.
Args:
intermediates_inputs: List of intermediate input parameters
required_intermediates_inputs: List of required intermediate input names
intermediates_outputs: List of intermediate output parameters
Returns:
str: Formatted string like:
Intermediates:
- inputs: Required(latents), dtype
- modified: latents # variables that appear in both inputs and outputs
- outputs: images # new outputs only
"""
# Handle inputs
input_parts = []
for inp in intermediates_inputs:
if inp.name in required_intermediates_inputs:
input_parts.append(f"Required({inp.name})")
else:
if inp.name is None and inp.kwargs_type is not None:
inp_name = "*_" + inp.kwargs_type
else:
inp_name = inp.name
input_parts.append(inp_name)
# Handle modified variables (appear in both inputs and outputs)
inputs_set = {inp.name for inp in intermediates_inputs}
modified_parts = []
new_output_parts = []
for out in intermediates_outputs:
if out.name in inputs_set:
modified_parts.append(out.name)
else:
new_output_parts.append(out.name)
result = []
if input_parts:
result.append(f" - inputs: {', '.join(input_parts)}")
if modified_parts:
result.append(f" - modified: {', '.join(modified_parts)}")
if new_output_parts:
result.append(f" - outputs: {', '.join(new_output_parts)}")
return "\n".join(result) if result else " (none)"
def format_params(params, header="Args", indent_level=4, max_line_length=115):
"""Format a list of InputParam or OutputParam objects into a readable string representation.
Args:
params: List of InputParam or OutputParam objects to format
header: Header text to use (e.g. "Args" or "Returns")
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all parameters
"""
if not params:
return ""
base_indent = " " * indent_level
param_indent = " " * (indent_level + 4)
desc_indent = " " * (indent_level + 8)
formatted_params = []
def get_type_str(type_hint):
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
return f"Union[{', '.join(types)}]"
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
def wrap_text(text, indent, max_length):
"""Wrap text while preserving markdown links and maintaining indentation."""
words = text.split()
lines = []
current_line = []
current_length = 0
for word in words:
word_length = len(word) + (1 if current_line else 0)
if current_line and current_length + word_length > max_length:
lines.append(" ".join(current_line))
current_line = [word]
current_length = len(word)
else:
current_line.append(word)
current_length += word_length
if current_line:
lines.append(" ".join(current_line))
return f"\n{indent}".join(lines)
# Add the header
formatted_params.append(f"{base_indent}{header}:")
for param in params:
# Format parameter name and type
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
# YiYi Notes: remove this line if we remove kwargs_type
name = f'**{param.kwargs_type}' if param.name is None and param.kwargs_type is not None else param.name
param_str = f"{param_indent}{name} (`{type_str}`"
# Add optional tag and default value if parameter is an InputParam and optional
if hasattr(param, "required"):
if not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to {param.default}"
param_str += "):"
# Add description on a new line with additional indentation and wrapping
if param.description:
desc = re.sub(
r'\[(.*?)\]\((https?://[^\s\)]+)\)',
r'[\1](\2)',
param.description
)
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
param_str += f"\n{desc_indent}{wrapped_desc}"
formatted_params.append(param_str)
return "\n\n".join(formatted_params)
def format_input_params(input_params, indent_level=4, max_line_length=115):
"""Format a list of InputParam objects into a readable string representation.
Args:
input_params: List of InputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all input parameters
"""
return format_params(input_params, "Inputs", indent_level, max_line_length)
def format_output_params(output_params, indent_level=4, max_line_length=115):
"""Format a list of OutputParam objects into a readable string representation.
Args:
output_params: List of OutputParam objects to format
indent_level: Number of spaces to indent each parameter line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
Returns:
A formatted string representing all output parameters
"""
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
Args:
components: List of ComponentSpec objects to format
indent_level: Number of spaces to indent each component line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between components (default: True)
Returns:
A formatted string representing all components
"""
if not components:
return ""
base_indent = " " * indent_level
component_indent = " " * (indent_level + 4)
formatted_components = []
# Add the header
formatted_components.append(f"{base_indent}Components:")
if add_empty_lines:
formatted_components.append("")
# Add each component with optional empty lines between them
for i, component in enumerate(components):
# Get type name, handling special cases
type_name = component.type_hint.__name__ if hasattr(component.type_hint, "__name__") else str(component.type_hint)
component_desc = f"{component_indent}{component.name} (`{type_name}`)"
if component.description:
component_desc += f": {component.description}"
# Get the loading fields dynamically
loading_field_values = []
for field_name in component.loading_fields():
field_value = getattr(component, field_name)
if field_value is not None:
loading_field_values.append(f"{field_name}={field_value}")
# Add loading field information if available
if loading_field_values:
component_desc += f" [{', '.join(loading_field_values)}]"
formatted_components.append(component_desc)
# Add an empty line after each component except the last one
if add_empty_lines and i < len(components) - 1:
formatted_components.append("")
return "\n".join(formatted_components)
def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ConfigSpec objects into a readable string representation.
Args:
configs: List of ConfigSpec objects to format
indent_level: Number of spaces to indent each config line (default: 4)
max_line_length: Maximum length for each line before wrapping (default: 115)
add_empty_lines: Whether to add empty lines between configs (default: True)
Returns:
A formatted string representing all configs
"""
if not configs:
return ""
base_indent = " " * indent_level
config_indent = " " * (indent_level + 4)
formatted_configs = []
# Add the header
formatted_configs.append(f"{base_indent}Configs:")
if add_empty_lines:
formatted_configs.append("")
# Add each config with optional empty lines between them
for i, config in enumerate(configs):
config_desc = f"{config_indent}{config.name} (default: {config.default})"
if config.description:
config_desc += f": {config.description}"
formatted_configs.append(config_desc)
# Add an empty line after each config except the last one
if add_empty_lines and i < len(configs) - 1:
formatted_configs.append("")
return "\n".join(formatted_configs)
def make_doc_string(inputs, intermediates_inputs, outputs, description="", class_name=None, expected_components=None, expected_configs=None):
"""
Generates a formatted documentation string describing the pipeline block's parameters and structure.
Args:
inputs: List of input parameters
intermediates_inputs: List of intermediate input parameters
outputs: List of output parameters
description (str, *optional*): Description of the block
class_name (str, *optional*): Name of the class to include in the documentation
expected_components (List[ComponentSpec], *optional*): List of expected components
expected_configs (List[ConfigSpec], *optional*): List of expected configurations
Returns:
str: A formatted string containing information about components, configs, call parameters,
intermediate inputs/outputs, and final outputs.
"""
output = ""
# Add class name if provided
if class_name:
output += f"class {class_name}\n\n"
# Add description
if description:
desc_lines = description.strip().split('\n')
aligned_desc = '\n'.join(' ' + line for line in desc_lines)
output += aligned_desc + "\n\n"
# Add components section if provided
if expected_components and len(expected_components) > 0:
components_str = format_components(expected_components, indent_level=2)
output += components_str + "\n\n"
# Add configs section if provided
if expected_configs and len(expected_configs) > 0:
configs_str = format_configs(expected_configs, indent_level=2)
output += configs_str + "\n\n"
# Add inputs section
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
# Add outputs section
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
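# Usage sketch (hypothetical block metadata; InputParam/OutputParam come from
# modular_pipeline_utils):
#   doc = make_doc_string(
#       inputs=[InputParam("prompt", type_hint=str, description="text prompt")],
#       intermediates_inputs=[],
#       outputs=[OutputParam("images", description="generated images")],
#       description="Example block",
#       class_name="ExampleBlock",
#   )
#   # doc starts with "class ExampleBlock", then the description, inputs and outputs sections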

View File

@@ -0,0 +1,519 @@
from ..configuration_utils import ConfigMixin
from .modular_pipeline import SequentialPipelineBlocks, ModularPipelineBlocks
from .modular_pipeline_utils import InputParam, OutputParam
from ..image_processor import PipelineImageInput
from pathlib import Path
import json
import os
from typing import Union, List, Optional, Tuple
import torch
import PIL
import numpy as np
import logging
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
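# Note: because the intermediate schema is merged last, its entries shadow raw
# inputs that share a key (e.g. "latents", "timesteps", "num_inference_steps").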
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name; return None if the name is not part of a group.
e.g. "prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
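# Examples (matching is substring-based, so related names share a group):
#   get_group_name("text_encoder_2")             -> "text_encoders"
#   get_group_name("negative_ip_adapter_embeds") -> "ip_adapter_embeds"
#   get_group_name("prompt")                     -> None (no group key is a substring of it)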
class ModularNode(ConfigMixin):
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt:
# prompt = {
#     "name": "text_input",  # the name of the input in the node definition; can differ from the diffusers input name
#     "label": "Prompt",
#     "type": "string",
#     "default": "a bear sitting in a chair drinking a milkshake",
#     "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler={"name": "scheduler"});
# it will get this spec in the node definition: {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict; in that case it is part of a "dict" input in mellon nodes, e.g. text_encoder={"name": {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediates_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.blocks.keys())[-1]
outputs = self.blocks.blocks[last_block_name].intermediates_outputs
else:
outputs = self.blocks.intermediates_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components, collection=None):
self.blocks.setup_loader(component_manager=components, collection=collection)
self._components_manager = components
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
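# Illustrative shape of the returned node dict (labels/params depend on the blocks):
# {
#     "label": "StableDiffusionXLTextEncoderStep",
#     "category": "Modular Diffusers",
#     "params": {
#         "prompt": {"label": "Prompt", "type": "string", "display": "textarea", ...},
#         "text_encoders": {"label": "text_encoders", "type": "text_encoders", "display": "input"},
#     },
# }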
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {
"module": self.mellon_config,
"name_mapping": self.name_mapping
}
# Save the config to file
with open(file_path, 'w', encoding='utf-8') as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, 'r', encoding='utf-8') as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
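# Round-trip sketch (path is illustrative):
#   path = node.save_mellon_config("configs/node_config.json")
#   config = ModularNode.load_mellon_config(path)
#   assert set(config) == {"module", "name_mapping"}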
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.blocks.loader.update(**params_components)
output = self.blocks.run(**params_run, output=return_output_names)
return output
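# End-to-end sketch (repo id, components manager, and kwargs are hypothetical):
#   node = ModularNode.from_pretrained("org/sdxl-modular-blocks", trust_remote_code=True)
#   node.setup(components)
#   images = node.execute(prompt="a bear sitting in a chair", num_inference_steps=25)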

View File

@@ -0,0 +1,53 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_pipeline_presets"] = ["StableDiffusionXLAutoPipeline"]
_import_structure["modular_loader"] = ["StableDiffusionXLModularLoader"]
_import_structure["encoders"] = ["StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLTextEncoderStep", "StableDiffusionXLAutoVaeEncoderStep"]
_import_structure["decoders"] = ["StableDiffusionXLAutoDecodeStep"]
_import_structure["modular_block_mappings"] = ["TEXT2IMAGE_BLOCKS", "IMAGE2IMAGE_BLOCKS", "INPAINT_BLOCKS", "CONTROLNET_BLOCKS", "CONTROLNET_UNION_BLOCKS", "IP_ADAPTER_BLOCKS", "AUTO_BLOCKS", "SDXL_SUPPORTED_BLOCKS"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_pipeline_presets import StableDiffusionXLAutoPipeline
from .modular_loader import StableDiffusionXLModularLoader
from .encoders import StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoVaeEncoderStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .modular_block_mappings import SDXL_SUPPORTED_BLOCKS, TEXT2IMAGE_BLOCKS, IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, CONTROLNET_BLOCKS, CONTROLNET_UNION_BLOCKS, IP_ADAPTER_BLOCKS, AUTO_BLOCKS
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

File diff suppressed because it is too large

View File

@@ -0,0 +1,215 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...configuration_utils import FrozenDict
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from ..modular_pipeline import (
AutoPipelineBlocks,
PipelineBlock,
PipelineState,
SequentialPipelineBlocks,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return "Step that decodes the denoised latents into images"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("output_type", default="pil"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents from the denoising step")]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images, can be a PIL.Image.Image, torch.Tensor or a numpy array")]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae with self -> components
@staticmethod
def upcast_vae(components):
dtype = components.vae.dtype
components.vae.to(dtype=torch.float32)
use_torch_2_0_or_xformers = isinstance(
components.vae.decoder.mid_block.attentions[0].processor,
(
AttnProcessor2_0,
XFormersAttnProcessor,
),
)
# if xformers or torch_2_0 is used attention block does not need
# to be in float32 which can save lots of memory
if use_torch_2_0_or_xformers:
components.vae.post_quant_conv.to(dtype)
components.vae.decoder.conv_in.to(dtype)
components.vae.decoder.mid_block.to(dtype)
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if not block_state.output_type == "latent":
latents = block_state.latents
# make sure the VAE is in float32 mode, as it overflows in float16
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
if block_state.needs_upcasting:
self.upcast_vae(components)
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
elif latents.dtype != components.vae.dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
components.vae = components.vae.to(latents.dtype)
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
block_state.has_latents_mean = (
hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None
)
block_state.has_latents_std = (
hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None
)
if block_state.has_latents_mean and block_state.has_latents_std:
block_state.latents_mean = (
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
block_state.latents_std = (
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
)
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
else:
latents = latents / components.vae.config.scaling_factor
block_state.images = components.vae.decode(latents, return_dict=False)[0]
# cast back to fp16 if needed
if block_state.needs_upcasting:
components.vae.to(dtype=torch.float16)
else:
block_state.images = block_state.latents
# apply watermark if available
if hasattr(components, "watermark") and components.watermark is not None:
block_state.images = components.watermark.apply_watermark(block_state.images)
block_state.images = components.image_processor.postprocess(block_state.images, output_type=block_state.output_type)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return "A post-processing step that overlays the mask on the image (inpainting task only).\n" + \
"only needed when you are using the `padding_mask_crop` option when pre-processing the image and mask"
@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[str]:
return [
InputParam("images", required=True, type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images from the decode step"),
InputParam("crops_coords", required=True, type_hint=Tuple[int, int], description="The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can be generated in vae_encode step.")
]
@property
def intermediates_outputs(self) -> List[str]:
return [OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="The generated images with the mask overlayed")]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
if block_state.padding_mask_crop is not None and block_state.crops_coords is not None:
block_state.images = [components.image_processor.apply_overlay(block_state.mask_image, block_state.image, i, block_state.crops_coords) for i in block_state.images]
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintDecodeStep(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLDecodeStep, StableDiffusionXLInpaintOverlayMaskStep]
block_names = ["decode", "mask_overlay"]
@property
def description(self):
return "Inpaint decode step that decode the denoised latents into images outputs.\n" + \
"This is a sequential pipeline blocks:\n" + \
" - `StableDiffusionXLDecodeStep` is used to decode the denoised latents into images\n" + \
" - `StableDiffusionXLInpaintOverlayMaskStep` is used to overlay the mask on the image"
class StableDiffusionXLAutoDecodeStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintDecodeStep, StableDiffusionXLDecodeStep]
block_names = ["inpaint", "non-inpaint"]
block_trigger_inputs = ["padding_mask_crop", None]
@property
def description(self):
return "Decode step that decode the denoised latents into images outputs.\n" + \
"This is an auto pipeline block that works for inpainting and non-inpainting tasks.\n" + \
" - `StableDiffusionXLInpaintDecodeStep` (inpaint) is used when `padding_mask_crop` is provided.\n" + \
" - `StableDiffusionXLDecodeStep` (non-inpaint) is used when `padding_mask_crop` is not provided."

File diff suppressed because it is too large

View File

@@ -0,0 +1,858 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
from collections import OrderedDict
from ...image_processor import VaeImageProcessor, PipelineImageInput
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...models import ControlNetModel, ImageProjection, UNet2DConditionModel, AutoencoderKL, ControlNetUnionModel
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...models.lora import adjust_lora_scale_text_encoder
from ...utils import (
USE_PEFT_BACKEND,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor, unwrap_module
from ...pipelines.controlnet.multicontrolnet import MultiControlNetModel
from ...configuration_utils import FrozenDict
from transformers import (
CLIPTextModel,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from ...schedulers import EulerDiscreteScheduler
from ...guiders import ClassifierFreeGuidance
from .modular_loader import StableDiffusionXLModularLoader
from ..modular_pipeline import PipelineBlock, PipelineState, AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam, ConfigSpec
import numpy as np
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
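# e.g. retrieve_latents(vae.encode(image), generator) draws a sample from the
# posterior; sample_mode="argmax" returns the distribution mode (deterministic) instead.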
class StableDiffusionXLIPAdapterStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"IP Adapter step that handles all the ip adapter related tasks: Load/unload ip adapter weights into unet, prepare ip adapter image embeddings, etc"
" See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)"
" for more details"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
ComponentSpec("feature_extractor", CLIPImageProcessor, config=FrozenDict({"size": 224, "crop_size": 224}), default_creation_method="from_config"),
ComponentSpec("unet", UNet2DConditionModel),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"ip_adapter_image",
PipelineImageInput,
required=True,
description="The image(s) to be used as ip adapter"
)
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("ip_adapter_embeds", type_hint=torch.Tensor, description="IP adapter image embeddings"),
OutputParam("negative_ip_adapter_embeds", type_hint=torch.Tensor, description="Negative IP adapter image embeddings")
]
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image with self -> components
@staticmethod
def encode_image(components, image, device, num_images_per_prompt, output_hidden_states=None):
dtype = next(components.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = components.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = components.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_enc_hidden_states = components.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = components.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
uncond_image_embeds = torch.zeros_like(image_embeds)
return image_embeds, uncond_image_embeds
# modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, components, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, prepare_unconditional_embeds
):
image_embeds = []
if prepare_unconditional_embeds:
negative_image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(components.unet.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(components.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, components.unet.encoder_hid_proj.image_projection_layers
):
output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
single_image_embeds, single_negative_image_embeds = self.encode_image(
components, single_ip_adapter_image, device, 1, output_hidden_state
)
image_embeds.append(single_image_embeds[None, :])
if prepare_unconditional_embeds:
negative_image_embeds.append(single_negative_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
if prepare_unconditional_embeds:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
negative_image_embeds.append(single_negative_image_embeds)
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
if prepare_unconditional_embeds:
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
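# Returned layout (sketch): one tensor per IP adapter; each is repeated
# num_images_per_prompt times along dim 0, and when unconditional embeddings are
# prepared, the negative embeddings are concatenated before the positive ones.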
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
block_state.ip_adapter_embeds = self.prepare_ip_adapter_image_embeds(
components,
ip_adapter_image=block_state.ip_adapter_image,
ip_adapter_image_embeds=None,
device=block_state.device,
num_images_per_prompt=1,
prepare_unconditional_embeds=block_state.prepare_unconditional_embeds,
)
if block_state.prepare_unconditional_embeds:
block_state.negative_ip_adapter_embeds = []
for i, image_embeds in enumerate(block_state.ip_adapter_embeds):
negative_image_embeds, image_embeds = image_embeds.chunk(2)
block_state.negative_ip_adapter_embeds.append(negative_image_embeds)
block_state.ip_adapter_embeds[i] = image_embeds
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Text Encoder step that generates text embeddings to guide the image generation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("text_encoder", CLIPTextModel),
ComponentSpec("text_encoder_2", CLIPTextModelWithProjection),
ComponentSpec("tokenizer", CLIPTokenizer),
ComponentSpec("tokenizer_2", CLIPTokenizer),
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"),
]
@property
def expected_configs(self) -> List[ConfigSpec]:
return [ConfigSpec("force_zeros_for_empty_prompt", True)]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("prompt"),
InputParam("prompt_2"),
InputParam("negative_prompt"),
InputParam("negative_prompt_2"),
InputParam("cross_attention_kwargs"),
InputParam("clip_skip"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [
OutputParam("prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields",description="text embeddings used to guide the image generation"),
OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative text embeddings used to guide the image generation"),
OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="pooled text embeddings used to guide the image generation"),
OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, kwargs_type="guider_input_fields", description="negative pooled text embeddings used to guide the image generation"),
]
@staticmethod
def check_inputs(block_state):
if block_state.prompt is not None and (not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
elif block_state.prompt_2 is not None and (not isinstance(block_state.prompt_2, str) and not isinstance(block_state.prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(block_state.prompt_2)}")
@staticmethod
def encode_prompt(
components,
prompt: str,
prompt_2: Optional[str] = None,
device: Optional[torch.device] = None,
num_images_per_prompt: int = 1,
prepare_unconditional_embeds: bool = True,
negative_prompt: Optional[str] = None,
negative_prompt_2: Optional[str] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
device: (`torch.device`):
torch device
num_images_per_prompt (`int`):
number of images that should be generated per prompt
prepare_unconditional_embeds (`bool`):
whether to use prepare unconditional embeddings or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
lora_scale (`float`, *optional*):
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
"""
device = device or components._execution_device
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(components, StableDiffusionXLLoraLoaderMixin):
components._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if components.text_encoder is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder, lora_scale)
else:
scale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(components.text_encoder_2, lora_scale)
else:
scale_lora_layers(components.text_encoder_2, lora_scale)
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Define tokenizers and text encoders
tokenizers = [components.tokenizer, components.tokenizer_2] if components.tokenizer is not None else [components.tokenizer_2]
text_encoders = (
[components.text_encoder, components.text_encoder_2] if components.text_encoder is not None else [components.text_encoder_2]
)
if prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
# textual inversion: process multi-vector tokens if necessary
prompt_embeds_list = []
prompts = [prompt, prompt_2]
for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
prompt = components.maybe_convert_prompt(prompt, tokenizer)
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
# We are only ALWAYS interested in the pooled output of the final text encoder
pooled_prompt_embeds = prompt_embeds[0]
if clip_skip is None:
prompt_embeds = prompt_embeds.hidden_states[-2]
else:
# "2" because SDXL always indexes from the penultimate layer.
prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
# get unconditional embeddings for classifier free guidance
zero_out_negative_prompt = negative_prompt is None and components.config.force_zeros_for_empty_prompt
if prepare_unconditional_embeds and negative_prompt_embeds is None and zero_out_negative_prompt:
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif prepare_unconditional_embeds and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
uncond_tokens: List[str]
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
else:
uncond_tokens = [negative_prompt, negative_prompt_2]
negative_prompt_embeds_list = []
for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
if isinstance(components, TextualInversionLoaderMixin):
negative_prompt = components.maybe_convert_prompt(negative_prompt, tokenizer)
max_length = prompt_embeds.shape[1]
uncond_input = tokenizer(
negative_prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
negative_prompt_embeds = text_encoder(
uncond_input.input_ids.to(device),
output_hidden_states=True,
)
# We are only ALWAYS interested in the pooled output of the final text encoder
negative_pooled_prompt_embeds = negative_prompt_embeds[0]
negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
negative_prompt_embeds_list.append(negative_prompt_embeds)
negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
if components.text_encoder_2 is not None:
prompt_embeds = prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
prompt_embeds = prompt_embeds.to(dtype=components.unet.dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prepare_unconditional_embeds:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
if components.text_encoder_2 is not None:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.text_encoder_2.dtype, device=device)
else:
negative_prompt_embeds = negative_prompt_embeds.to(dtype=components.unet.dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if prepare_unconditional_embeds:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
bs_embed * num_images_per_prompt, -1
)
if components.text_encoder is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder, lora_scale)
if components.text_encoder_2 is not None:
if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(components.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
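# Call sketch (encode_prompt is a staticmethod; `components` is a loader that
# exposes the tokenizers/text encoders):
#   embeds, neg_embeds, pooled, neg_pooled = StableDiffusionXLTextEncoderStep.encode_prompt(
#       components, prompt="a photo of a cat", num_images_per_prompt=1,
#       prepare_unconditional_embeds=True, negative_prompt="low quality",
#   )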
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
# Get inputs and intermediates
block_state = self.get_block_state(state)
self.check_inputs(block_state)
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
block_state.device = components._execution_device
# Encode input prompt
block_state.text_encoder_lora_scale = (
block_state.cross_attention_kwargs.get("scale", None) if block_state.cross_attention_kwargs is not None else None
)
(
block_state.prompt_embeds,
block_state.negative_prompt_embeds,
block_state.pooled_prompt_embeds,
block_state.negative_pooled_prompt_embeds,
) = self.encode_prompt(
components,
block_state.prompt,
block_state.prompt_2,
block_state.device,
1,
block_state.prepare_unconditional_embeds,
block_state.negative_prompt,
block_state.negative_prompt_2,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
lora_scale=block_state.text_encoder_lora_scale,
clip_skip=block_state.clip_skip,
)
# Add outputs
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def description(self) -> str:
return (
"Vae Encoder step that encode the input image into a latent representation"
)
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam("preprocess_kwargs", type_hint=Optional[dict], description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]")]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representing the reference image for image-to-image/inpainting generation")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use #Copied from
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
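# Normalization sketch: with latents_mean/std available, the encoded latents are
# standardized as (z - mean) * scaling_factor / std; otherwise they are scaled by
# vae.config.scaling_factor alone.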
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.preprocess_kwargs = block_state.preprocess_kwargs or {}
block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs)
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
self.add_block_state(state, block_state)
return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
model_name = "stable-diffusion-xl"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKL),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 8}),
default_creation_method="from_config"),
ComponentSpec(
"mask_processor",
VaeImageProcessor,
config=FrozenDict({"do_normalize": False, "vae_scale_factor": 8, "do_binarize": True, "do_convert_grayscale": True}),
default_creation_method="from_config"),
]
@property
def description(self) -> str:
return (
"Vae encoder step that prepares the image and mask for the inpainting process"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("height"),
InputParam("width"),
InputParam("image", required=True),
InputParam("mask_image", required=True),
InputParam("padding_mask_crop"),
]
@property
def intermediates_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"),
]
@property
def intermediates_outputs(self) -> List[OutputParam]:
return [OutputParam("image_latents", type_hint=torch.Tensor, description="The latents representation of the input image"),
OutputParam("mask", type_hint=torch.Tensor, description="The mask to use for the inpainting process"),
OutputParam("masked_image_latents", type_hint=torch.Tensor, description="The masked image latents to use for the inpainting process (only for inpainting-specifid unet)"),
OutputParam("crops_coords", type_hint=Optional[Tuple[int, int]], description="The crop coordinates to use for the preprocess/postprocess of the image and mask")]
# Modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline._encode_vae_image with self -> components
# YiYi TODO: update the _encode_vae_image so that we can use "# Copied from"
def _encode_vae_image(self, components, image: torch.Tensor, generator: torch.Generator):
latents_mean = latents_std = None
if hasattr(components.vae.config, "latents_mean") and components.vae.config.latents_mean is not None:
latents_mean = torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1)
if hasattr(components.vae.config, "latents_std") and components.vae.config.latents_std is not None:
latents_std = torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1)
dtype = image.dtype
if components.vae.config.force_upcast:
image = image.float()
components.vae.to(dtype=torch.float32)
if isinstance(generator, list):
image_latents = [
retrieve_latents(components.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(components.vae.encode(image), generator=generator)
if components.vae.config.force_upcast:
components.vae.to(dtype)
image_latents = image_latents.to(dtype)
if latents_mean is not None and latents_std is not None:
latents_mean = latents_mean.to(device=image_latents.device, dtype=dtype)
latents_std = latents_std.to(device=image_latents.device, dtype=dtype)
image_latents = (image_latents - latents_mean) * components.vae.config.scaling_factor / latents_std
else:
image_latents = components.vae.config.scaling_factor * image_latents
return image_latents
# modified from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipeline.prepare_mask_latents
# does not accept do_classifier_free_guidance
def prepare_mask_latents(
self, components, mask, masked_image, batch_size, height, width, dtype, device, generator
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // components.vae_scale_factor, width // components.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image is not None and masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
masked_image_latents = None
if masked_image is not None:
if masked_image_latents is None:
masked_image = masked_image.to(device=device, dtype=dtype)
masked_image_latents = self._encode_vae_image(components, masked_image, generator=generator)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
# align device to prevent device errors when concatenating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.device = components._execution_device
if block_state.padding_mask_crop is not None:
block_state.crops_coords = components.mask_processor.get_crop_region(block_state.mask_image, block_state.width, block_state.height, pad=block_state.padding_mask_crop)
block_state.resize_mode = "fill"
else:
block_state.crops_coords = None
block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess(block_state.image, height=block_state.height, width=block_state.width, crops_coords=block_state.crops_coords, resize_mode=block_state.resize_mode)
block_state.image = block_state.image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess(block_state.mask_image, height=block_state.height, width=block_state.width, resize_mode=block_state.resize_mode, crops_coords=block_state.crops_coords)
block_state.masked_image = block_state.image * (block_state.mask < 0.5)
block_state.batch_size = block_state.image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image(components, image=block_state.image, generator=block_state.generator)
# 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components,
block_state.mask,
block_state.masked_image,
block_state.batch_size,
block_state.height,
block_state.width,
block_state.dtype,
block_state.device,
block_state.generator,
)
self.add_block_state(state, block_state)
return components, state
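# Hedged sketch (shapes assumed): the mask resizing done in prepare_mask_latents above.
# A pixel-space mask is downsampled to latent resolution before being concatenated with
# the latents, so for SDXL (vae_scale_factor=8) a 1024x1024 mask becomes 128x128.
def _demo_mask_to_latent_resolution():
    vae_scale_factor = 8
    mask = (torch.rand(1, 1, 1024, 1024) > 0.5).float()  # binarized pixel-space mask
    return torch.nn.functional.interpolate(
        mask, size=(1024 // vae_scale_factor, 1024 // vae_scale_factor)
    )  # shape (1, 1, 128, 128)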
# auto blocks (YiYi TODO: maybe move all the auto blocks to a separate file)
# Encode
class StableDiffusionXLAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [StableDiffusionXLInpaintVaeEncoderStep, StableDiffusionXLVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return "Vae encoder step that encode the image inputs into their latent representations.\n" + \
"This is an auto pipeline block that works for both inpainting and img2img tasks.\n" + \
" - `StableDiffusionXLInpaintVaeEncoderStep` (inpaint) is used when both `mask_image` and `image` are provided.\n" + \
" - `StableDiffusionXLVaeEncoderStep` (img2img) is used when only `image` is provided."
class StableDiffusionXLAutoIPAdapterStep(AutoPipelineBlocks, ModularIPAdapterMixin):
block_classes = [StableDiffusionXLIPAdapterStep]
block_names = ["ip_adapter"]
block_trigger_inputs = ["ip_adapter_image"]
@property
def description(self):
return "Run IP Adapter step if `ip_adapter_image` is provided."

@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..modular_pipeline_utils import InsertableOrderedDict
# Import all the necessary block classes
from .denoise import (
StableDiffusionXLAutoDenoiseStep,
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDenoiseLoop,
StableDiffusionXLInpaintDenoiseLoop
)
from .before_denoise import (
StableDiffusionXLAutoBeforeDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLImg2ImgSetTimestepsStep,
StableDiffusionXLImg2ImgPrepareLatentsStep,
StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep,
StableDiffusionXLInpaintPrepareLatentsStep,
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep
)
from .encoders import (
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
StableDiffusionXLVaeEncoderStep,
StableDiffusionXLInpaintVaeEncoderStep,
StableDiffusionXLIPAdapterStep
)
from .decoders import (
StableDiffusionXLDecodeStep,
StableDiffusionXLInpaintDecodeStep,
StableDiffusionXLAutoDecodeStep
)
# YiYi notes: comment out for now, work on this later
# block mapping
TEXT2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
IMAGE2IMAGE_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseLoop),
("decode", StableDiffusionXLDecodeStep)
])
INPAINT_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("image_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLInpaintDenoiseLoop),
("decode", StableDiffusionXLInpaintDecodeStep)
])
CONTROLNET_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
CONTROLNET_UNION_BLOCKS = InsertableOrderedDict([
("controlnet_input", StableDiffusionXLControlNetUnionInputStep),
("denoise", StableDiffusionXLControlNetDenoiseStep),
])
IP_ADAPTER_BLOCKS = InsertableOrderedDict([
("ip_adapter", StableDiffusionXLIPAdapterStep),
])
AUTO_BLOCKS = InsertableOrderedDict([
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("image_encoder", StableDiffusionXLAutoVaeEncoderStep),
("before_denoise", StableDiffusionXLAutoBeforeDenoiseStep),
("denoise", StableDiffusionXLAutoDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep)
])
SDXL_SUPPORTED_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"controlnet_union": CONTROLNET_UNION_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS
}
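# Hedged sketch (plain OrderedDict semantics assumed; InsertableOrderedDict's positional
# insert API is not shown in this file): a variant workflow can be assembled by rebuilding
# a preset, e.g. grafting the ControlNet blocks onto TEXT2IMAGE_BLOCKS with the
# controlnet_input step placed just before the swapped-in denoise loop.
def _demo_controlnet_text2img_blocks():
    blocks = InsertableOrderedDict()
    for name, block_cls in TEXT2IMAGE_BLOCKS.items():
        if name == "denoise":
            blocks["controlnet_input"] = StableDiffusionXLControlNetInputStep
            blocks["denoise"] = StableDiffusionXLControlNetDenoiseStep
        else:
            blocks[name] = block_cls
    return blocks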

@@ -0,0 +1,174 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
import PIL
import torch
import numpy as np
from ...loaders import StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin, ModularIPAdapterMixin
from ...image_processor import PipelineImageInput
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
from ...utils import logging
from ..modular_pipeline import ModularLoader
from ..modular_pipeline_utils import InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder?
# YiYi Notes: model specific components:
## (1) it should inherit from ModularLoader
## (2) acts like a container that holds components and configs
## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents
## (4) inherit from model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin)
## (5) how should it be used together with the components manager?
class StableDiffusionXLModularLoader(
ModularLoader,
StableDiffusionMixin,
TextualInversionLoaderMixin,
StableDiffusionXLLoraLoaderMixin,
ModularIPAdapterMixin,
):
@property
def default_sample_size(self):
default_sample_size = 128
if hasattr(self, "unet") and self.unet is not None:
default_sample_size = self.unet.config.sample_size
return default_sample_size
@property
def vae_scale_factor(self):
vae_scale_factor = 8
if hasattr(self, "vae") and self.vae is not None:
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
return vae_scale_factor
@property
def num_channels_unet(self):
num_channels_unet = 4
if hasattr(self, "unet") and self.unet is not None:
num_channels_unet = self.unet.config.in_channels
return num_channels_unet
@property
def num_channels_latents(self):
num_channels_latents = 4
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
return num_channels_latents
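# Hedged worked example of the fallback logic above: the standard SDXL VAE config has
# block_out_channels = [128, 256, 512, 512], so vae_scale_factor = 2 ** (4 - 1) = 8 and
# latent_channels = 4, matching the num_channels_latents default.
def _demo_vae_scale_factor():
    block_out_channels = [128, 256, 512, 512]  # assumed SDXL VAE config values
    return 2 ** (len(block_out_channels) - 1)  # -> 8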
# YiYi Notes: not used yet; maintain a list of schemas that can be used across all pipeline blocks
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam("prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"),
"prompt_2": InputParam("prompt_2", type_hint=Union[str, List[str]], description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2"),
"negative_prompt": InputParam("negative_prompt", type_hint=Union[str, List[str]], description="The prompt or prompts not to guide the image generation"),
"negative_prompt_2": InputParam("negative_prompt_2", type_hint=Union[str, List[str]], description="The negative prompt or prompts for text_encoder_2"),
"cross_attention_kwargs": InputParam("cross_attention_kwargs", type_hint=Optional[dict], description="Kwargs dictionary passed to the AttentionProcessor"),
"clip_skip": InputParam("clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"),
"image": InputParam("image", type_hint=PipelineImageInput, required=True, description="The image(s) to modify for img2img or inpainting"),
"mask_image": InputParam("mask_image", type_hint=PipelineImageInput, required=True, description="Mask image for inpainting, white pixels will be repainted"),
"generator": InputParam("generator", type_hint=Optional[Union[torch.Generator, List[torch.Generator]]], description="Generator(s) for deterministic generation"),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam("num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"),
"timesteps": InputParam("timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"),
"sigmas": InputParam("sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"),
"denoising_end": InputParam("denoising_end", type_hint=Optional[float], description="Fraction of denoising process to complete before termination"),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam("strength", type_hint=float, default=0.3, description="How much to transform the reference image"),
"denoising_start": InputParam("denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"),
"latents": InputParam("latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"),
"padding_mask_crop": InputParam("padding_mask_crop", type_hint=Optional[Tuple[int, int]], description="Size of margin in crop for image and mask"),
"original_size": InputParam("original_size", type_hint=Optional[Tuple[int, int]], description="Original size of the image for SDXL's micro-conditioning"),
"target_size": InputParam("target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"),
"negative_original_size": InputParam("negative_original_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on image resolution"),
"negative_target_size": InputParam("negative_target_size", type_hint=Optional[Tuple[int, int]], description="Negative conditioning based on target resolution"),
"crops_coords_top_left": InputParam("crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Top-left coordinates for SDXL's micro-conditioning"),
"negative_crops_coords_top_left": InputParam("negative_crops_coords_top_left", type_hint=Tuple[int, int], default=(0, 0), description="Negative conditioning crop coordinates"),
"aesthetic_score": InputParam("aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"),
"negative_aesthetic_score": InputParam("negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam("output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"),
"ip_adapter_image": InputParam("ip_adapter_image", type_hint=PipelineImageInput, required=True, description="Image(s) to be used as IP adapter"),
"control_image": InputParam("control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"),
"control_guidance_start": InputParam("control_guidance_start", type_hint=Union[float, List[float]], default=0.0, description="When ControlNet starts applying"),
"control_guidance_end": InputParam("control_guidance_end", type_hint=Union[float, List[float]], default=1.0, description="When ControlNet stops applying"),
"controlnet_conditioning_scale": InputParam("controlnet_conditioning_scale", type_hint=Union[float, List[float]], default=1.0, description="Scale factor for ControlNet outputs"),
"guess_mode": InputParam("guess_mode", type_hint=bool, default=False, description="Enables ControlNet encoder to recognize input without prompts"),
"control_mode": InputParam("control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet")
}
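# Hedged sketch (the helper below is hypothetical): since the schema maps names to
# ready-made InputParam objects, a block's `inputs` property could be assembled by
# lookup instead of re-declaring every parameter, which is the reuse the note above anticipates.
def _inputs_from_schema(*names):
    return [SDXL_INPUTS_SCHEMA[name] for name in names]
# e.g. _inputs_from_schema("image", "strength", "num_inference_steps")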
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam("prompt_embeds", type_hint=torch.Tensor, required=True, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam("preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"),
"latents": InputParam("latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam("num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"),
"latent_timestep": InputParam("latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"),
"image_latents": InputParam("image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"add_time_ids": InputParam("add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"),
"negative_add_time_ids": InputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": InputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": InputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], required=True, description="Generated images")
}
SDXL_INTERMEDIATE_OUTPUTS_SCHEMA = {
"prompt_embeds": OutputParam("prompt_embeds", type_hint=torch.Tensor, description="Text embeddings used to guide image generation"),
"negative_prompt_embeds": OutputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"),
"pooled_prompt_embeds": OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pooled text embeddings"),
"negative_pooled_prompt_embeds": OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"),
"batch_size": OutputParam("batch_size", type_hint=int, description="Number of prompts"),
"dtype": OutputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"image_latents": OutputParam("image_latents", type_hint=torch.Tensor, description="Latents representing reference image"),
"mask": OutputParam("mask", type_hint=torch.Tensor, description="Mask for inpainting"),
"masked_image_latents": OutputParam("masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"),
"crops_coords": OutputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"timesteps": OutputParam("timesteps", type_hint=torch.Tensor, description="Timesteps for inference"),
"num_inference_steps": OutputParam("num_inference_steps", type_hint=int, description="Number of denoising steps"),
"latent_timestep": OutputParam("latent_timestep", type_hint=torch.Tensor, description="Initial noise level timestep"),
"add_time_ids": OutputParam("add_time_ids", type_hint=torch.Tensor, description="Time ids for conditioning"),
"negative_add_time_ids": OutputParam("negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"),
"timestep_cond": OutputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"latents": OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
"noise": OutputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"ip_adapter_embeds": OutputParam("ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"),
"negative_ip_adapter_embeds": OutputParam("negative_ip_adapter_embeds", type_hint=List[torch.Tensor], description="Negative image embeddings for IP-Adapter"),
"images": OutputParam("images", type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], description="Generated images")
}
SDXL_OUTPUTS_SCHEMA = {
"images": OutputParam("images", type_hint=Union[Tuple[Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]]], StableDiffusionXLPipelineOutput], description="The final generated images")
}

@@ -0,0 +1,43 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional, Tuple, Union, Dict
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from .denoise import StableDiffusionXLAutoDenoiseStep
from .before_denoise import StableDiffusionXLAutoBeforeDenoiseStep
from .decoders import StableDiffusionXLAutoDecodeStep
from .encoders import StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLAutoPipeline(SequentialPipelineBlocks):
block_classes = [StableDiffusionXLTextEncoderStep, StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, StableDiffusionXLAutoBeforeDenoiseStep, StableDiffusionXLAutoDenoiseStep, StableDiffusionXLAutoDecodeStep]
block_names = ["text_encoder", "ip_adapter", "image_encoder", "before_denoise", "denoise", "decoder"]
@property
def description(self):
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n" + \
"- for image-to-image generation, you need to provide either `image` or `image_latents`\n" + \
"- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n" + \
"- to run the controlnet workflow, you need to provide `control_image`\n" + \
"- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n" + \
"- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n" + \
"- for text-to-image generation, all you need to provide is `prompt`"

@@ -47,7 +47,6 @@ else:
"AutoPipelineForInpainting",
"AutoPipelineForText2Image",
]
_import_structure["modular_pipeline"] = ["ModularPipeline"]
_import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
_import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
_import_structure["ddim"] = ["DDIMPipeline"]
@@ -330,8 +329,6 @@ else:
"StableDiffusionXLInpaintPipeline",
"StableDiffusionXLInstructPix2PixPipeline",
"StableDiffusionXLPipeline",
"StableDiffusionXLModularPipeline",
"StableDiffusionXLAutoPipeline",
]
)
_import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -481,7 +478,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .modular_pipeline import ModularPipeline
from .pipeline_utils import (
AudioPipelineOutput,
DiffusionPipeline,
@@ -706,9 +702,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLImg2ImgPipeline,
StableDiffusionXLInpaintPipeline,
StableDiffusionXLInstructPix2PixPipeline,
StableDiffusionXLModularPipeline,
StableDiffusionXLPipeline,
StableDiffusionXLAutoPipeline,
)
from .stable_video_diffusion import StableVideoDiffusionPipeline
from .t2i_adapter import (

File diff suppressed because it is too large

@@ -331,6 +331,20 @@ def maybe_raise_or_warn(
)
# a simpler version of get_class_obj_and_candidates, it won't work with custom code
def simple_get_class_obj(library_name, class_name):
from diffusers import pipelines
is_pipeline_module = hasattr(pipelines, library_name)
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
else:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
return class_obj
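# Hedged usage sketch for simple_get_class_obj above (class name illustrative): a
# library_name that is not a diffusers.pipelines submodule falls through to the plain
# importlib branch.
def _demo_simple_get_class_obj():
    return simple_get_class_obj("diffusers", "UNet2DConditionModel")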
def get_class_obj_and_candidates(
library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
@@ -412,7 +426,7 @@ def _get_pipeline_class(
revision=revision,
)
if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
if class_obj.__name__ != "DiffusionPipeline":
return class_obj
diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])
@@ -839,7 +853,10 @@ def _fetch_class_library_tuple(module):
library = not_compiled_module.__module__
# retrieve class_name
class_name = not_compiled_module.__class__.__name__
if isinstance(not_compiled_module, type):
class_name = not_compiled_module.__name__
else:
class_name = not_compiled_module.__class__.__name__
return (library, class_name)
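# Hedged illustration of the isinstance(not_compiled_module, type) branch above: the
# helper now yields a class_name whether it is handed a class (as modular component
# specs may do before any instance exists) or an instance.
def _demo_class_name(obj):
    return obj.__name__ if isinstance(obj, type) else obj.__class__.__name__
# _demo_class_name(dict) -> "dict"; _demo_class_name({}) -> "dict"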

@@ -1948,9 +1948,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
}
optional_components = pipeline._optional_components if hasattr(pipeline, "_optional_components") and pipeline._optional_components else []
missing_modules = (
set(expected_modules)
- set(pipeline._optional_components)
- set(optional_components)
- set(pipeline_kwargs.keys())
- set(true_optional_modules)
)
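# Hedged toy example of the set arithmetic above (module names invented): components
# that are neither optional, nor passed in as kwargs, nor truly optional are reported
# as missing.
def _demo_missing_modules():
    expected_modules = {"unet", "vae", "scheduler", "feature_extractor"}
    optional_components = {"feature_extractor"}
    pipeline_kwargs = {"scheduler": object()}
    true_optional_modules = set()
    # -> {"unet", "vae"}
    return expected_modules - optional_components - set(pipeline_kwargs) - true_optional_modules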

@@ -29,18 +29,6 @@ else:
_import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
_import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
_import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
_import_structure["pipeline_stable_diffusion_xl_modular"] = [
"StableDiffusionXLControlNetDenoiseStep",
"StableDiffusionXLDecodeLatentsStep",
"StableDiffusionXLDenoiseStep",
"StableDiffusionXLInputStep",
"StableDiffusionXLModularPipeline",
"StableDiffusionXLPrepareAdditionalConditioningStep",
"StableDiffusionXLPrepareLatentsStep",
"StableDiffusionXLSetTimestepsStep",
"StableDiffusionXLTextEncoderStep",
"StableDiffusionXLAutoPipeline",
]
if is_transformers_available() and is_flax_available():
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
@@ -60,18 +48,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
from .pipeline_stable_diffusion_xl_modular import (
StableDiffusionXLControlNetDenoiseStep,
StableDiffusionXLDecodeLatentsStep,
StableDiffusionXLDenoiseStep,
StableDiffusionXLInputStep,
StableDiffusionXLModularPipeline,
StableDiffusionXLPrepareAdditionalConditioningStep,
StableDiffusionXLPrepareLatentsStep,
StableDiffusionXLSetTimestepsStep,
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoPipeline,
)
try:
if not (is_transformers_available() and is_flax_available()):

@@ -1388,7 +1388,7 @@ class LDMSuperResolutionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class ModularPipeline(metaclass=DummyObject):
class ModularLoader(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):

@@ -2432,7 +2432,7 @@ class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionXLModularPipeline(metaclass=DummyObject):
class StableDiffusionXLModularLoader(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):

@@ -15,13 +15,16 @@
"""Utilities to dynamically load objects from the Hub."""
import importlib
import signal
import inspect
import json
import os
import re
import shutil
import sys
import threading
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional, Union
from urllib import request
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
_HF_REMOTE_CODE_LOCK = threading.Lock()
def get_diffusers_versions():
@@ -154,15 +159,87 @@ def check_imports(filename):
return get_relative_imports(filename)
def get_class_in_module(class_name, module_path):
def _raise_timeout_error(signum, frame):
raise ValueError(
"Loading this model requires you to execute custom code contained in the model repository on your local "
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
)
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
if trust_remote_code is None:
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
prev_sig_handler = None
try:
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
signal.alarm(TIME_OUT_REMOTE_CODE)
while trust_remote_code is None:
answer = input(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
f"Do you wish to run the custom code? [y/N] "
)
if answer.lower() in ["yes", "y", "1"]:
trust_remote_code = True
elif answer.lower() in ["no", "n", "0", ""]:
trust_remote_code = False
signal.alarm(0)
except Exception:
# OS which does not support signal.SIGALRM
raise ValueError(
f"The repository for {model_name} contains custom code which must be executed to correctly "
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
)
finally:
if prev_sig_handler is not None:
signal.signal(signal.SIGALRM, prev_sig_handler)
signal.alarm(0)
elif has_remote_code:
# For the CI which puts the timeout at 0
_raise_timeout_error(None, None)
if has_remote_code and not trust_remote_code:
raise ValueError(
f"Loading {model_name} requires you to execute the configuration file in that"
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
" set the option `trust_remote_code=True` to remove this error."
)
return trust_remote_code
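# Hedged usage sketch (repo id invented): non-interactive callers (CI, servers) should
# pass trust_remote_code explicitly so the SIGALRM-guarded prompt above is bypassed.
def _demo_resolve_trust():
    return resolve_trust_remote_code(True, "user/custom-modular-repo", has_remote_code=True)  # True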
def get_class_in_module(class_name, module_path, force_reload=False):
"""
Import a module on the cache directory for modules and extract a class from it.
"""
module_path = module_path.replace(os.path.sep, ".")
module = importlib.import_module(module_path)
name = os.path.normpath(module_path)
if name.endswith(".py"):
name = name[:-3]
name = name.replace(os.path.sep, ".")
module_file: Path = Path(HF_MODULES_CACHE) / module_path
with _HF_REMOTE_CODE_LOCK:
if force_reload:
sys.modules.pop(name, None)
importlib.invalidate_caches()
cached_module: Optional[ModuleType] = sys.modules.get(name)
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
module: ModuleType
if cached_module is None:
module = importlib.util.module_from_spec(module_spec)
# insert it into sys.modules before any loading begins
sys.modules[name] = module
else:
module = cached_module
module_spec.loader.exec_module(module)
if class_name is None:
return find_pipeline_class(module)
return getattr(module, class_name)
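# Hedged usage sketch (module path invented): force_reload=True evicts the cached module
# under _HF_REMOTE_CODE_LOCK and re-executes it, so edited remote block code is picked up
# without restarting the interpreter.
# block_cls = get_class_in_module("MyBlock", "local/my_block.py", force_reload=True)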
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
revision=revision,
local_files_only=local_files_only,
)
return get_class_in_module(class_name, final_module.replace(".py", ""))
return get_class_in_module(class_name, final_module)

@@ -90,6 +90,11 @@ def is_compiled_module(module) -> bool:
return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
def unwrap_module(module):
"""Unwraps a module if it was compiled with torch.compile()"""
return module._orig_mod if is_compiled_module(module) else module
def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
"""Fourier filter as introduced in FreeU (https://arxiv.org/abs/2309.11497).