Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 20:44:33 +08:00)

Compare commits: attn-refac...modular-di (70 commits)
| SHA1 |
|---|
| 6a509ba862 |
| 96795afc72 |
| 12650e1393 |
| addaad013c |
| 485f8d1758 |
| cff0fd6260 |
| 8ddb20bfb8 |
| e5089d702b |
| 2c3e4eafa8 |
| c7020df2cf |
| 4bed3e306e |
| 00a3bc9d6c |
| ccb35acd81 |
| 00cae4e857 |
| b3fb4188f5 |
| 71df1581f7 |
| d046cf7d35 |
| 68a5185c86 |
| 6e2fe26bfd |
| 77b5fa59c5 |
| a226920b52 |
| 7007f72409 |
| a6804de4a2 |
| 7f897a9fc4 |
| 0966663d2a |
| fb78f4f12d |
| 2220af6940 |
| 7a34832d52 |
| e973de64f9 |
| db94ca882d |
| 6985906a2e |
| 54f410db6c |
| c12a05b9c1 |
| 2e0f5c86cc |
| 1d63306295 |
| 6c93626f6f |
| 72c5bf07c8 |
| ed59f90f15 |
| a09ca7f27e |
| 8c02572e16 |
| 27dde51de8 |
| 10d4a775f1 |
| 72d9a81d99 |
| 4fa85c7963 |
| 806e8e66fb |
| 0b90051db8 |
| b305c779b2 |
| 2b3cd2d39c |
| bc3d1c9ee6 |
| e50d614636 |
| a8df0f1ffb |
| ace53e2d2f |
| ffc2992fc2 |
| c70a285c2c |
| 8b811feece |
| 37e8dc7a59 |
| 024a9f5de3 |
| 005195c23e |
| 6742f160df |
| 540d303250 |
| f1b3036ca1 |
| 46ec1743a2 |
| 70272b1108 |
| 2b6dcbfa1d |
| af9572d759 |
| ddea157979 |
| ad3f9a26c0 |
| e8d0980f9f |
| 52a7f1cb97 |
| 33f85fadf6 |
@@ -239,6 +239,7 @@ else:
            "KarrasVePipeline",
            "LDMPipeline",
            "LDMSuperResolutionPipeline",
            "ModularPipeline",
            "PNDMPipeline",
            "RePaintPipeline",
            "ScoreSdeVePipeline",
@@ -493,10 +494,12 @@ else:
            "StableDiffusionXLImg2ImgPipeline",
            "StableDiffusionXLInpaintPipeline",
            "StableDiffusionXLInstructPix2PixPipeline",
            "StableDiffusionXLModularPipeline",
            "StableDiffusionXLPAGImg2ImgPipeline",
            "StableDiffusionXLPAGInpaintPipeline",
            "StableDiffusionXLPAGPipeline",
            "StableDiffusionXLPipeline",
            "StableDiffusionXLAutoPipeline",
            "StableUnCLIPImg2ImgPipeline",
            "StableUnCLIPPipeline",
            "StableVideoDiffusionPipeline",
@@ -834,6 +837,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        KarrasVePipeline,
        LDMPipeline,
        LDMSuperResolutionPipeline,
        ModularPipeline,
        PNDMPipeline,
        RePaintPipeline,
        ScoreSdeVePipeline,
@@ -1066,10 +1070,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLInstructPix2PixPipeline,
        StableDiffusionXLModularPipeline,
        StableDiffusionXLPAGImg2ImgPipeline,
        StableDiffusionXLPAGInpaintPipeline,
        StableDiffusionXLPAGPipeline,
        StableDiffusionXLPipeline,
        StableDiffusionXLAutoPipeline,
        StableUnCLIPImg2ImgPipeline,
        StableUnCLIPPipeline,
        StableVideoDiffusionPipeline,
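
The hunks above register the new modular pipeline classes in the top-level `_import_structure` and `TYPE_CHECKING` imports. A minimal sketch of the resulting import surface (assuming this branch of diffusers is installed):

```python
# Assumes this branch of diffusers is installed; these are exactly the names added above.
from diffusers import ModularPipeline, StableDiffusionXLAutoPipeline, StableDiffusionXLModularPipeline
```
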
src/diffusers/guider.py (new file, 745 lines)
@@ -0,0 +1,745 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from .models.attention_processor import (
    Attention,
    AttentionProcessor,
    PAGCFGIdentitySelfAttnProcessor2_0,
    PAGIdentitySelfAttnProcessor2_0,
)
from .utils import logging


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


class CFGGuider:
    """
    This class is used to guide the pipeline with CFG (Classifier-Free Guidance).
    """

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0 and not self._disable_guidance

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def batch_size(self):
        return self._batch_size

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        # a flag to disable CFG, e.g. we disable it for LCM and use a guidance scale embedding instead
        disable_guidance = guider_kwargs.get("disable_guidance", False)
        guidance_scale = guider_kwargs.get("guidance_scale", None)
        if guidance_scale is None:
            raise ValueError("guidance_scale is not provided in guider_kwargs")
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is not provided in guider_kwargs")
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        self._disable_guidance = disable_guidance

    def reset_guider(self, pipeline):
        pass

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

        This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 2:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size :]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Classifier-Free Guidance (CFG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 2

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Prepare the input for CFG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The conditional input. It can be a single tensor or a
                list of tensors. It must have the same length as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
                single tensor or a list of tensors. It must have the same length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
        """

        # we check if cond_input already has CFG applied, and split if it is the case.
        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contains negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_classifier_free_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_classifier_free_guidance:
                return cond_input
            else:
                return torch.cat([negative_cond_input, cond_input], dim=0)

        else:
            raise ValueError(f"Unsupported input type: {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_classifier_free_guidance:
            return model_output

        noise_pred_uncond, noise_pred_text = model_output.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred


class PAGGuider:
    """
    This class is used to guide the pipeline with PAG (Perturbed Attention Guidance).
    """

    def __init__(
        self,
        pag_applied_layers: Union[str, List[str]],
        pag_attn_processors: Tuple[AttentionProcessor, AttentionProcessor] = (
            PAGCFGIdentitySelfAttnProcessor2_0(),
            PAGIdentitySelfAttnProcessor2_0(),
        ),
    ):
        r"""
        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

        Args:
            pag_applied_layers (`str` or `List[str]`):
                One or more strings identifying the layer names, or a simple regex for matching multiple layers, where
                PAG is to be applied. A few ways of expected usage are as follows:
                  - Single layers specified as - "blocks.{layer_index}"
                  - Multiple layers as a list - ["blocks.{layers_index_1}", "blocks.{layer_index_2}", ...]
                  - Multiple layers as a block name - "mid"
                  - Multiple layers as regex - "blocks.({layer_index_1}|{layer_index_2})"
            pag_attn_processors:
                (`Tuple[AttentionProcessor, AttentionProcessor]`, defaults to `(PAGCFGIdentitySelfAttnProcessor2_0(),
                PAGIdentitySelfAttnProcessor2_0())`): A tuple of two attention processors. The first attention
                processor is for PAG with Classifier-free guidance enabled (conditional and unconditional). The second
                attention processor is for PAG with CFG disabled (unconditional only).
        """

        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        if pag_attn_processors is not None:
            if not isinstance(pag_attn_processors, tuple) or len(pag_attn_processors) != 2:
                raise ValueError("Expected a tuple of two attention processors")

        for i in range(len(pag_applied_layers)):
            if not isinstance(pag_applied_layers[i], str):
                raise ValueError(
                    f"Expected either a string or a list of string but got type {type(pag_applied_layers[i])}"
                )

        self.pag_applied_layers = pag_applied_layers
        self._pag_attn_processors = pag_attn_processors

    def _set_pag_attn_processor(self, model, pag_applied_layers, do_classifier_free_guidance):
        r"""
        Set the attention processor for the PAG layers.
        """
        pag_attn_processors = self._pag_attn_processors
        pag_attn_proc = pag_attn_processors[0] if do_classifier_free_guidance else pag_attn_processors[1]

        def is_self_attn(module: nn.Module) -> bool:
            r"""
            Check if the module is self-attention module based on its name.
            """
            return isinstance(module, Attention) and not module.is_cross_attention

        def is_fake_integral_match(layer_id, name):
            layer_id = layer_id.split(".")[-1]
            name = name.split(".")[-1]
            return layer_id.isnumeric() and name.isnumeric() and layer_id == name

        for layer_id in pag_applied_layers:
            # for each PAG layer input, we find corresponding self-attention layers in the unet model
            target_modules = []

            for name, module in model.named_modules():
                # Identify the following simple cases:
                #   (1) Self Attention layer existing
                #   (2) Whether the module name matches pag layer id even partially
                #   (3) Make sure it's not a fake integral match if the layer_id ends with a number
                #       For example, blocks.1, blocks.10 should be differentiable if layer_id="blocks.1"
                if (
                    is_self_attn(module)
                    and re.search(layer_id, name) is not None
                    and not is_fake_integral_match(layer_id, name)
                ):
                    logger.debug(f"Applying PAG to layer: {name}")
                    target_modules.append(module)

            if len(target_modules) == 0:
                raise ValueError(f"Cannot find PAG layer to set attention processor for: {layer_id}")

            for module in target_modules:
                module.processor = pag_attn_proc

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and not self._disable_guidance

    @property
    def do_perturbed_attention_guidance(self):
        return self._pag_scale > 0 and not self._disable_guidance

    @property
    def do_pag_adaptive_scaling(self):
        return self._pag_adaptive_scale > 0 and self._pag_scale > 0 and not self._disable_guidance

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def pag_scale(self):
        return self._pag_scale

    @property
    def pag_adaptive_scale(self):
        return self._pag_adaptive_scale

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        pag_scale = guider_kwargs.get("pag_scale", 3.0)
        pag_adaptive_scale = guider_kwargs.get("pag_adaptive_scale", 0.0)

        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is a required argument for PAGGuider")

        guidance_scale = guider_kwargs.get("guidance_scale", None)
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        disable_guidance = guider_kwargs.get("disable_guidance", False)

        if guidance_scale is None:
            raise ValueError("guidance_scale is a required argument for PAGGuider")

        self._pag_scale = pag_scale
        self._pag_adaptive_scale = pag_adaptive_scale
        self._guidance_scale = guidance_scale
        self._disable_guidance = disable_guidance
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        if not hasattr(pipeline, "original_attn_proc") or pipeline.original_attn_proc is None:
            pipeline.original_attn_proc = pipeline.unet.attn_processors
        self._set_pag_attn_processor(
            model=pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer,
            pag_applied_layers=self.pag_applied_layers,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )

    def reset_guider(self, pipeline):
        if (
            self.do_perturbed_attention_guidance
            and hasattr(pipeline, "original_attn_proc")
            and pipeline.original_attn_proc is not None
        ):
            pipeline.unet.set_attn_processor(pipeline.original_attn_proc)
            pipeline.original_attn_proc = None

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Perturbed Attention Guidance (PAG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 3

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

        This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 3:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size : self.batch_size * 2]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]:
        """
        Prepare the input for PAG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
                The conditional input. It can be a single tensor or a
                list of tensors. It must have the same length as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]):
                The negative conditional input. It can be a single tensor or a list of tensors. It must have the same
                length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]: The prepared input.
        """

        # we check if cond_input already has CFG applied, and split if it is the case.

        if self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_perturbed_attention_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_perturbed_attention_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contains negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_perturbed_attention_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")

            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")

                cond = torch.cat([cond] * 2, dim=0)
                if self.do_classifier_free_guidance:
                    prepared_input.append(torch.cat([neg_cond, cond], dim=0))
                else:
                    prepared_input.append(cond)

            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_perturbed_attention_guidance:
                return cond_input

            cond_input = torch.cat([cond_input] * 2, dim=0)
            if self.do_classifier_free_guidance:
                return torch.cat([negative_cond_input, cond_input], dim=0)
            else:
                return cond_input

        else:
            raise ValueError(f"Unsupported input type: {type(negative_cond_input)} and {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_perturbed_attention_guidance:
            return model_output

        if self.do_pag_adaptive_scaling:
            pag_scale = max(self._pag_scale - self._pag_adaptive_scale * (1000 - timestep), 0)
        else:
            pag_scale = self._pag_scale

        if self.do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text, noise_pred_perturb = model_output.chunk(3)
            noise_pred = (
                noise_pred_uncond
                + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
                + pag_scale * (noise_pred_text - noise_pred_perturb)
            )
        else:
            noise_pred_text, noise_pred_perturb = model_output.chunk(2)
            noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

        return noise_pred


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


class APGGuider:
    """
    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
    """

    def normalized_guidance(
        self,
        pred_cond: torch.Tensor,
        pred_uncond: torch.Tensor,
        guidance_scale: float,
        momentum_buffer: MomentumBuffer = None,
        norm_threshold: float = 0.0,
        eta: float = 1.0,
    ):
        """
        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
        Models](https://arxiv.org/pdf/2410.02416)
        """
        diff = pred_cond - pred_uncond
        if momentum_buffer is not None:
            momentum_buffer.update(diff)
            diff = momentum_buffer.running_average
        if norm_threshold > 0:
            ones = torch.ones_like(diff)
            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
            diff = diff * scale_factor
        v0, v1 = diff.double(), pred_cond.double()
        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
        v0_orthogonal = v0 - v0_parallel
        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
        normalized_update = diff_orthogonal + eta * diff_parallel
        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
        return pred_guided

    @property
    def adaptive_projected_guidance_momentum(self):
        return self._adaptive_projected_guidance_momentum

    @property
    def adaptive_projected_guidance_rescale_factor(self):
        return self._adaptive_projected_guidance_rescale_factor

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0 and not self._disable_guidance

    @property
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def batch_size(self):
        return self._batch_size

    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
        disable_guidance = guider_kwargs.get("disable_guidance", False)
        guidance_scale = guider_kwargs.get("guidance_scale", None)
        if guidance_scale is None:
            raise ValueError("guidance_scale is not provided in guider_kwargs")
        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
            "adaptive_projected_guidance_rescale_factor", 15.0
        )
        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
        batch_size = guider_kwargs.get("batch_size", None)
        if batch_size is None:
            raise ValueError("batch_size is not provided in guider_kwargs")
        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._batch_size = batch_size
        self._disable_guidance = disable_guidance
        if adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
        else:
            self.momentum_buffer = None
        self.scheduler = pipeline.scheduler

    def reset_guider(self, pipeline):
        pass

    def maybe_update_guider(self, pipeline, timestep):
        pass

    def maybe_update_input(self, pipeline, cond_input):
        pass

    def _maybe_split_prepared_input(self, cond):
        """
        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

        This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
        It determines whether to split the input based on its batch size relative to the expected batch size.

        Args:
            cond (torch.Tensor): The conditional input tensor to process.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - The negative conditional input (uncond_input)
                - The positive conditional input (cond_input)
        """
        if cond.shape[0] == self.batch_size * 2:
            neg_cond = cond[0 : self.batch_size]
            cond = cond[self.batch_size :]
            return neg_cond, cond
        elif cond.shape[0] == self.batch_size:
            return cond, cond
        else:
            raise ValueError(f"Unsupported input shape: {cond.shape}")

    def _is_prepared_input(self, cond):
        """
        Check if the input is already prepared for Classifier-Free Guidance (CFG).

        Args:
            cond (torch.Tensor): The conditional input tensor to check.

        Returns:
            bool: True if the input is already prepared, False otherwise.
        """
        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

        return cond_tensor.shape[0] == self.batch_size * 2

    def prepare_input(
        self,
        cond_input: Union[torch.Tensor, List[torch.Tensor]],
        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Prepare the input for CFG.

        Args:
            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
                The conditional input. It can be a single tensor or a
                list of tensors. It must have the same length as `negative_cond_input`.
            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
                single tensor or a list of tensors. It must have the same length as `cond_input`.

        Returns:
            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
        """

        # we check if cond_input already has CFG applied, and split if it is the case.
        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
            return cond_input

        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
            if isinstance(cond_input, list):
                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
            else:
                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

        if not self._is_prepared_input(cond_input) and self.do_classifier_free_guidance and negative_cond_input is None:
            raise ValueError(
                "`negative_cond_input` is required when cond_input does not already contains negative conditional input"
            )

        if isinstance(cond_input, (list, tuple)):
            if not self.do_classifier_free_guidance:
                return cond_input

            if len(negative_cond_input) != len(cond_input):
                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
            prepared_input = []
            for neg_cond, cond in zip(negative_cond_input, cond_input):
                if neg_cond.shape[0] != cond.shape[0]:
                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
            return prepared_input

        elif isinstance(cond_input, torch.Tensor):
            if not self.do_classifier_free_guidance:
                return cond_input
            else:
                return torch.cat([negative_cond_input, cond_input], dim=0)

        else:
            raise ValueError(f"Unsupported input type: {type(cond_input)}")

    def apply_guidance(
        self,
        model_output: torch.Tensor,
        timestep: int = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if not self.do_classifier_free_guidance:
            return model_output

        if latents is None:
            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")

        sigma = self.scheduler.sigmas[self.scheduler.step_index]
        noise_pred = latents - sigma * model_output
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = self.normalized_guidance(
            noise_pred_text,
            noise_pred_uncond,
            self.guidance_scale,
            self.momentum_buffer,
            self.adaptive_projected_guidance_rescale_factor,
        )
        noise_pred = (latents - noise_pred) / sigma

        if self.guidance_rescale > 0.0:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
        return noise_pred
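
The three guider classes above share the same small interface (`set_guider`, `prepare_input`, `apply_guidance`, `reset_guider`). Below is a minimal sketch of how a denoising loop is expected to drive it, using `CFGGuider`; the tensor shapes and `guider_kwargs` values are illustrative only and not part of the file:

```python
import torch

guider = CFGGuider()
guider.set_guider(pipeline=None, guider_kwargs={"guidance_scale": 7.5, "batch_size": 1})

prompt_embeds = torch.randn(1, 77, 2048)    # conditional text embeddings (illustrative shape)
negative_embeds = torch.zeros(1, 77, 2048)  # unconditional text embeddings

# With guidance_scale > 1, prepare_input concatenates [negative, positive] along the batch dim.
embeds = guider.prepare_input(prompt_embeds, negative_embeds)
assert embeds.shape[0] == 2  # batch_size * 2

# The denoiser is then run on the doubled batch; apply_guidance recombines the two halves.
model_output = torch.randn(2, 4, 128, 128)  # stand-in for the UNet/transformer prediction
noise_pred = guider.apply_guidance(model_output, timestep=999)
assert noise_pred.shape[0] == 1
```
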
@@ -84,6 +84,7 @@ if is_torch_available():
        "IPAdapterMixin",
        "FluxIPAdapterMixin",
        "SD3IPAdapterMixin",
        "ModularIPAdapterMixin",
    ]

    _import_structure["peft"] = ["PeftAdapterMixin"]
@@ -102,6 +103,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        FluxIPAdapterMixin,
        IPAdapterMixin,
        SD3IPAdapterMixin,
        ModularIPAdapterMixin,
    )
    from .lora_pipeline import (
        AmusedLoraLoaderMixin,
@@ -356,6 +356,265 @@ class IPAdapterMixin:
            )
        self.unet.set_attn_processor(attn_procs)


class ModularIPAdapterMixin:
    """Mixin for handling IP Adapters."""

    @validate_hf_hub_args
    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
        subfolder: Union[str, List[str]],
        weight_name: Union[str, List[str]],
        **kwargs,
    ):
        """
        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
            subfolder (`str` or `List[str]`):
                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
                list is passed, it should have the same length as `weight_name`.
            weight_name (`str` or `List[str]`):
                The name of the weight file to load. If a list is passed, it should have the same length as
                `subfolder`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """

        # handle the list inputs for multiple IP Adapters
        if not isinstance(weight_name, list):
            weight_name = [weight_name]

        if not isinstance(pretrained_model_name_or_path_or_dict, list):
            pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
        if len(pretrained_model_name_or_path_or_dict) == 1:
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)

        if not isinstance(subfolder, list):
            subfolder = [subfolder]
        if len(subfolder) == 1:
            subfolder = subfolder * len(weight_name)

        if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
            raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")

        if len(weight_name) != len(subfolder):
            raise ValueError("`weight_name` and `subfolder` must have the same length.")

        # Load the main state dict first.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }
        state_dicts = []
        for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
            pretrained_model_name_or_path_or_dict, weight_name, subfolder
        ):
            if not isinstance(pretrained_model_name_or_path_or_dict, dict):
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                if weight_name.endswith(".safetensors"):
                    state_dict = {"image_proj": {}, "ip_adapter": {}}
                    with safe_open(model_file, framework="pt", device="cpu") as f:
                        for key in f.keys():
                            if key.startswith("image_proj."):
                                state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                            elif key.startswith("ip_adapter."):
                                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
                else:
                    state_dict = load_state_dict(model_file)
            else:
                state_dict = pretrained_model_name_or_path_or_dict

            keys = list(state_dict.keys())
            if "image_proj" not in keys and "ip_adapter" not in keys:
                raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")

            state_dicts.append(state_dict)

        # create feature extractor if it has not been registered to the pipeline yet
        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
            # FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
            default_clip_size = 224
            clip_image_size = (
                self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
            )
            feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)

        unet_name = getattr(self, "unet_name", "unet")
        unet = getattr(self, unet_name)
        unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)

        extra_loras = unet._load_ip_adapter_loras(state_dicts)
        if extra_loras != {}:
            if not USE_PEFT_BACKEND:
                logger.warning("PEFT backend is required to load these weights.")
            else:
                # apply the IP Adapter Face ID LoRA weights
                peft_config = getattr(unet, "peft_config", {})
                for k, lora in extra_loras.items():
                    if f"faceid_{k}" not in peft_config:
                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])

    def set_ip_adapter_scale(self, scale):
        """
        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

        Example:

        ```py
        # To use original IP-Adapter
        scale = 1.0
        pipeline.set_ip_adapter_scale(scale)

        # To use style block only
        scale = {
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style+layout blocks
        scale = {
            "down": {"block_2": [0.0, 1.0]},
            "up": {"block_0": [0.0, 1.0, 0.0]},
        }
        pipeline.set_ip_adapter_scale(scale)

        # To use style and layout from 2 reference images
        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
        pipeline.set_ip_adapter_scale(scales)
        ```
        """
        unet_name = getattr(self, "unet_name", "unet")
        unet = getattr(self, unet_name)
        if not isinstance(scale, list):
            scale = [scale]
        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

        for attn_name, attn_processor in unet.attn_processors.items():
            if isinstance(
                attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
            ):
                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
                        f"Cannot assign {len(scale_configs)} scale_configs to "
                        f"{len(attn_processor.scale)} IP-Adapter."
                    )
                elif len(scale_configs) == 1:
                    scale_configs = scale_configs * len(attn_processor.scale)
                for i, scale_config in enumerate(scale_configs):
                    if isinstance(scale_config, dict):
                        for k, s in scale_config.items():
                            if attn_name.startswith(k):
                                attn_processor.scale[i] = s
                    else:
                        attn_processor.scale[i] = scale_config

    def unload_ip_adapter(self):
        """
        Unloads the IP Adapter weights

        Examples:

        ```python
        >>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
        >>> pipeline.unload_ip_adapter()
        >>> ...
        ```
        """

        # remove hidden encoder
        if self.unet is None:
            return

        self.unet.encoder_hid_proj = None
        self.unet.config.encoder_hid_dim_type = None

        # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
        if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
            self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
            self.unet.text_encoder_hid_proj = None
            self.unet.config.encoder_hid_dim_type = "text_proj"

        # restore original Unet attention processors layers
        attn_procs = {}
        for name, value in self.unet.attn_processors.items():
            attn_processor_class = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
            )
            attn_procs[name] = (
                attn_processor_class
                if isinstance(
                    value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
                )
                else value.__class__()
            )
        self.unet.set_attn_processor(attn_procs)


class FluxIPAdapterMixin:
    """Mixin for handling Flux IP Adapters."""
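
`ModularIPAdapterMixin` mirrors the existing `IPAdapterMixin` loading flow. A hypothetical call pattern, assuming `pipe` is a modular SDXL pipeline that exposes this mixin (the repository id, subfolder, and weight file names below are placeholders):

```python
# Hypothetical usage; repo id, subfolder, and weight file name are placeholders.
pipe.load_ip_adapter(
    "some-org/some-ip-adapter-repo",
    subfolder="sdxl_models",
    weight_name="ip-adapter_sdxl.safetensors",
)

# Scale can be a float, a per-block dict, or a list of configs (one per adapter),
# as in the docstring examples above.
pipe.set_ip_adapter_scale(0.6)
```
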
@@ -441,7 +441,7 @@ def _func_optionally_disable_offloading(_pipeline):
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False

-    if _pipeline is not None and _pipeline.hf_device_map is None:
+    if _pipeline is not None and hasattr(_pipeline, "hf_device_map") and _pipeline.hf_device_map is None:
        for _, component in _pipeline.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                if not is_model_cpu_offload:
@@ -491,6 +491,7 @@ class LoraBaseMixin:
            tuple:
                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
        """

        return _func_optionally_disable_offloading(_pipeline=_pipeline)

    @classmethod
@@ -713,8 +714,10 @@ class LoraBaseMixin:
        # Decompose weights into weights for denoiser and text encoders.
        _component_adapter_weights = {}
        for component in self._lora_loadable_modules:
-            model = getattr(self, component)
+            model = getattr(self, component, None)
+            if model is None:
+                logger.warning(f"Model {component} not found in pipeline.")
+                continue
            for adapter_name, weights in zip(adapter_names, adapter_weights):
                if isinstance(weights, dict):
                    component_adapter_weights = weights.pop(component, None)
@@ -636,7 +636,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
-            unet_config=self.unet.config,
+            unet_config=self.unet.config if hasattr(self, "unet") else None,
            **kwargs,
        )

@@ -644,37 +644,40 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_unet(
-            state_dict,
-            network_alphas=network_alphas,
-            unet=self.unet,
-            adapter_name=adapter_name,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
-        self.load_lora_into_text_encoder(
-            state_dict,
-            network_alphas=network_alphas,
-            text_encoder=self.text_encoder,
-            prefix=self.text_encoder_name,
-            lora_scale=self.lora_scale,
-            adapter_name=adapter_name,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
-        self.load_lora_into_text_encoder(
-            state_dict,
-            network_alphas=network_alphas,
-            text_encoder=self.text_encoder_2,
-            prefix=f"{self.text_encoder_name}_2",
-            lora_scale=self.lora_scale,
-            adapter_name=adapter_name,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        if hasattr(self, "unet"):
+            self.load_lora_into_unet(
+                state_dict,
+                network_alphas=network_alphas,
+                unet=self.unet,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        if hasattr(self, "text_encoder"):
+            self.load_lora_into_text_encoder(
+                state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix=self.text_encoder_name,
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        if hasattr(self, "text_encoder_2"):
+            self.load_lora_into_text_encoder(
+                state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix=f"{self.text_encoder_name}_2",
+                lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

    @classmethod
    @validate_hf_hub_args
@@ -408,6 +408,7 @@ class UNet2DConditionLoadersMixin:
            tuple:
                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
        """

        return _func_optionally_disable_offloading(_pipeline=_pipeline)

    def save_attn_procs(
@@ -47,6 +47,7 @@ else:
        "AutoPipelineForInpainting",
        "AutoPipelineForText2Image",
    ]
    _import_structure["modular_pipeline"] = ["ModularPipeline"]
    _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
    _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
    _import_structure["ddim"] = ["DDIMPipeline"]
@@ -329,6 +330,8 @@ else:
            "StableDiffusionXLInpaintPipeline",
            "StableDiffusionXLInstructPix2PixPipeline",
            "StableDiffusionXLPipeline",
            "StableDiffusionXLModularPipeline",
            "StableDiffusionXLAutoPipeline",
        ]
    )
    _import_structure["stable_diffusion_diffedit"] = ["StableDiffusionDiffEditPipeline"]
@@ -478,6 +481,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline
    from .dit import DiTPipeline
    from .latent_diffusion import LDMSuperResolutionPipeline
    from .modular_pipeline import ModularPipeline
    from .pipeline_utils import (
        AudioPipelineOutput,
        DiffusionPipeline,
@@ -702,7 +706,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLInstructPix2PixPipeline,
        StableDiffusionXLModularPipeline,
        StableDiffusionXLPipeline,
        StableDiffusionXLAutoPipeline,
    )
    from .stable_video_diffusion import StableVideoDiffusionPipeline
    from .t2i_adapter import (
@@ -246,14 +246,15 @@ def _get_connected_pipeline(pipeline_cls):
    return _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, pipeline_cls.__name__, throw_error_if_not_exist=False)


-def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
-    def get_model(pipeline_class_name):
-        for task_mapping in SUPPORTED_TASKS_MAPPINGS:
-            for model_name, pipeline in task_mapping.items():
-                if pipeline.__name__ == pipeline_class_name:
-                    return model_name
-
-    model_name = get_model(pipeline_class_name)
+def _get_model(pipeline_class_name):
+    for task_mapping in SUPPORTED_TASKS_MAPPINGS:
+        for model_name, pipeline in task_mapping.items():
+            if pipeline.__name__ == pipeline_class_name:
+                return model_name
+
+
+def _get_task_class(mapping, pipeline_class_name, throw_error_if_not_exist: bool = True):
+    model_name = _get_model(pipeline_class_name)

    if model_name is not None:
        task_class = mapping.get(model_name, None)
src/diffusers/pipelines/components_manager.py (new file, 609 lines)
@@ -0,0 +1,609 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from itertools import combinations
from typing import List, Optional, Union, Dict, Any
import copy

import torch
import time
from dataclasses import dataclass

from ..utils import (
    is_accelerate_available,
    logging,
)
from ..models.modeling_utils import ModelMixin


if is_accelerate_available():
    from accelerate.hooks import ModelHook, add_hook_to_module, remove_hook_from_module
    from accelerate.state import PartialState
    from accelerate.utils import send_to_device
    from accelerate.utils.memory import clear_device_cache
    from accelerate.utils.modeling import convert_file_size_to_int

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# YiYi Notes: copied from modeling_utils.py (decide later where to put this)
def get_memory_footprint(self, return_buffers=True):
    r"""
    Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to
    benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch
    discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

    Arguments:
        return_buffers (`bool`, *optional*, defaults to `True`):
            Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are
            tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm
            layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
    """
    mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
    if return_buffers:
        mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
        mem = mem + mem_bufs
    return mem


class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
    on the given device. Optionally offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
    """

    def __init__(
        self,
        execution_device: Optional[Union[str, int, torch.device]] = None,
        other_hooks: Optional[List["UserCustomOffloadHook"]] = None,
        offload_strategy: Optional["AutoOffloadStrategy"] = None,
    ):
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
                # offload all other hooks
                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)


class UserCustomOffloadHook:
|
||||
"""
|
||||
A simple hook grouping a model and a `CustomOffloadHook`, which provides easy APIs for to call the init method of
|
||||
the hook or remove it entirely.
|
||||
"""
|
||||
|
||||
def __init__(self, model_id, model, hook):
|
||||
self.model_id = model_id
|
||||
self.model = model
|
||||
self.hook = hook
|
||||
|
||||
def offload(self):
|
||||
self.hook.init_hook(self.model)
|
||||
|
||||
def attach(self):
|
||||
add_hook_to_module(self.model, self.hook)
|
||||
self.hook.model_id = self.model_id
|
||||
|
||||
def remove(self):
|
||||
remove_hook_from_module(self.model)
|
||||
self.hook.model_id = None
|
||||
|
||||
def add_other_hook(self, hook: "UserCustomOffloadHook"):
|
||||
self.hook.add_other_hook(hook)
|
||||
|
||||
|
||||
def custom_offload_with_hook(
|
||||
model_id: str,
|
||||
model: torch.nn.Module,
|
||||
execution_device: Union[str, int, torch.device] = None,
|
||||
offload_strategy: Optional["AutoOffloadStrategy"] = None,
|
||||
):
|
||||
hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
|
||||
user_hook = UserCustomOffloadHook(model_id=model_id, model=model, hook=hook)
|
||||
user_hook.attach()
|
||||
return user_hook
|
||||
|
||||
|
||||
class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
    the available memory on the device.
    """

    def __init__(self, memory_reserve_margin="3GB"):
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    def __call__(self, hooks, model_id, model, execution_device):
        if len(hooks) == 0:
            return []

        current_module_size = get_memory_footprint(model)

        mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0]
        mem_on_device = mem_on_device - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # exclude models that are not currently loaded on the device
        module_sizes = dict(
            sorted(
                {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the optimal combination of models to offload to cpu, given a dictionary of module sizes and a
            minimum memory offload size. The chosen combination of models should add up to the smallest total size
            that is larger than `min_memory_offload`.
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    if candidate_size < min_memory_offload:
                        continue
                    else:
                        if best_candidate is None or candidate_size < best_size:
                            best_candidate = candidate_model_ids
                            best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # if no combination is found, meaning that we cannot meet the memory requirement, offload all models
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload


class ComponentsManager:
    def __init__(self):
        self.components = OrderedDict()
        self.added_time = OrderedDict()  # Store when components were added
        self.model_hooks = None
        self._auto_offload_enabled = False

    def add(self, name, component):
        if name in self.components:
            logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
        self.components[name] = component
        self.added_time[name] = time.time()

        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    def remove(self, name):
        if name not in self.components:
            logger.warning(f"Component '{name}' not found in ComponentsManager")
            return

        self.components.pop(name)
        self.added_time.pop(name)

        if self._auto_offload_enabled:
            self.enable_auto_cpu_offload(self._auto_offload_device)

    # YiYi TODO: looking into improving the search pattern
    def get(self, names: Union[str, List[str]]):
        """
        Get components by name with simple pattern matching.

        Args:
            names: Component name(s) or pattern(s)
                Patterns:
                - "unet"                : exact match
                - "!unet"               : everything except exact match "unet"
                - "base_*"              : everything starting with "base_"
                - "!base_*"             : everything NOT starting with "base_"
                - "*unet*"              : anything containing "unet"
                - "!*unet*"             : anything NOT containing "unet"
                - "refiner|vae|unet"    : anything containing any of these terms
                - "!refiner|vae|unet"   : anything NOT containing any of these terms

        Returns:
            Single component if names is str and matches one component,
            dict of components if names matches multiple components or is a list
        """
        if isinstance(names, str):
            # Check if this is a "not" pattern
            is_not_pattern = names.startswith('!')
            if is_not_pattern:
                names = names[1:]  # Remove the ! prefix

            # Handle OR patterns (containing |)
            if '|' in names:
                terms = names.split('|')
                matches = {
                    name: comp for name, comp in self.components.items()
                    if any((term in name) != is_not_pattern for term in terms)  # Flip condition if not pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT containing any of {terms}: {list(matches.keys())}")
                else:
                    logger.info(f"Getting components containing any of {terms}: {list(matches.keys())}")

            # Exact match
            elif names in self.components:
                if is_not_pattern:
                    matches = {
                        name: comp for name, comp in self.components.items()
                        if name != names
                    }
                    logger.info(f"Getting all components except '{names}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting component: {names}")
                    return self.components[names]

            # Contains match (starts with *); checked before the prefix match so "*unet*" is treated as "contains"
            elif names.startswith('*'):
                search = names[1:-1] if names.endswith('*') else names[1:]
                matches = {
                    name: comp for name, comp in self.components.items()
                    if (search in name) != is_not_pattern  # Flip condition if not pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

            # Prefix match (ends with *)
            elif names.endswith('*'):
                prefix = names[:-1]
                matches = {
                    name: comp for name, comp in self.components.items()
                    if name.startswith(prefix) != is_not_pattern  # Flip condition if not pattern
                }
                if is_not_pattern:
                    logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
                else:
                    logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")

            else:
                raise ValueError(f"Component '{names}' not found in ComponentsManager")

            if not matches:
                raise ValueError(f"No components found matching pattern '{names}'")
            return matches if len(matches) > 1 else next(iter(matches.values()))

        elif isinstance(names, list):
            results = {}
            for name in names:
                result = self.get(name)
                if isinstance(result, dict):
                    results.update(result)
                else:
                    results[name] = result
            logger.info(f"Getting multiple components: {list(results.keys())}")
            return results

        else:
            raise ValueError(f"Invalid type for names: {type(names)}")

    def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda", memory_reserve_margin="3GB"):
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                remove_hook_from_module(component, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
        device = torch.device(device)
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")
        all_hooks = []
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
                all_hooks.append(hook)

        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device

    def disable_auto_cpu_offload(self):
        if self.model_hooks is None:
            self._auto_offload_enabled = False
            return

        for hook in self.model_hooks:
            hook.offload()
            hook.remove()
        if self.model_hooks:
            clear_device_cache()
        self.model_hooks = None
        self._auto_offload_enabled = False

    def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
        """Get comprehensive information about a component.

        Args:
            name: Name of the component to get info for
            fields: Optional field(s) to return. Can be a string for a single field or a list of fields.
                If None, returns all fields.

        Returns:
            Dictionary containing the requested component metadata.
            If fields is specified, returns only those fields.
            If a single field is requested as a string, returns a dict containing just that field.
        """
        if name not in self.components:
            raise ValueError(f"Component '{name}' not found in ComponentsManager")

        component = self.components[name]

        # Build complete info dict first
        info = {
            "model_id": name,
            "added_time": self.added_time[name],
        }

        # Additional info for torch.nn.Module components
        if isinstance(component, torch.nn.Module):
            info.update({
                "class_name": component.__class__.__name__,
                "size_gb": get_memory_footprint(component) / (1024**3),
                "adapters": None,  # Default to None
            })

            # Get adapters if applicable
            if hasattr(component, "peft_config"):
                info["adapters"] = list(component.peft_config.keys())

            # Check for IP-Adapter scales
            if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
                processors = copy.deepcopy(component.attn_processors)
                # First check if any processor is an IP-Adapter
                processor_types = [v.__class__.__name__ for v in processors.values()]
                if any("IPAdapter" in ptype for ptype in processor_types):
                    # Then get scales only from IP-Adapter processors
                    scales = {
                        k: v.scale
                        for k, v in processors.items()
                        if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
                    }
                    if scales:
                        info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)

        # If fields specified, filter info
        if fields is not None:
            if isinstance(fields, str):
                # Single field requested, return a dict with just that field
                return {fields: info.get(fields)}
            else:
                # List of fields requested, return dict with just those fields
                return {k: v for k, v in info.items() if k in fields}

        return info

    def __repr__(self):
        col_widths = {
            "id": max(15, max((len(name) for name in self.components.keys()), default=0)),
            "class": max(25, max((len(component.__class__.__name__) for component in self.components.values()), default=0)),
            "device": 10,
            "dtype": 15,
            "size": 10,
        }

        # Create the header lines
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "Components:\n" + sep_line

        # Separate components into models and others
        models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
        others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

        # Models section
        if models:
            output += "Models:\n" + dash_line
            # Column headers
            output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
            output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
            output += dash_line

            # Model entries
            for name, component in models.items():
                info = self.get_model_info(name)
                device = str(getattr(component, "device", "N/A"))
                dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
                output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
                output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
            output += dash_line

        # Other components section
        if others:
            if models:  # Add extra newline if we had models section
                output += "\n"
            output += "Other Components:\n" + dash_line
            # Column headers for other components
            output += f"{'Component ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}}\n"
            output += dash_line

            # Other component entries
            for name, component in others.items():
                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
            output += dash_line

        # Add additional component info
        output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
        for name in self.components:
            info = self.get_model_info(name)
            if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
                output += f"\n{name}:\n"
                if info.get("adapters") is not None:
                    output += f"  Adapters: {info['adapters']}\n"
                if info.get("ip_adapter"):
                    output += "  IP-Adapter: Enabled\n"
                output += f"  Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"

        return output

    def add_from_pretrained(self, pretrained_model_name_or_path, prefix: Optional[str] = None, **kwargs):
        """
        Load components from a pretrained model and add them to the manager.

        Args:
            pretrained_model_name_or_path (str): The path or identifier of the pretrained model
            prefix (str, optional): Prefix to add to all component names loaded from this model.
                If provided, components will be named as "{prefix}_{component_name}"
            **kwargs: Additional arguments to pass to DiffusionPipeline.from_pretrained()
        """
        from ..pipelines.pipeline_utils import DiffusionPipeline

        pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, **kwargs)
        for name, component in pipe.components.items():

            if component is None:
                continue

            # Add prefix if specified
            component_name = f"{prefix}_{name}" if prefix else name

            if component_name not in self.components:
                self.add(component_name, component)
            else:
                logger.warning(
                    f"Component '{component_name}' already exists in ComponentsManager and will not be added. To add it, either:\n"
                    f"1. remove the existing component with remove('{component_name}')\n"
                    f"2. Use a different prefix: add_from_pretrained(..., prefix='{prefix}_2')"
                )

def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
    """Summarizes a dictionary by finding common prefixes that share the same value.

    For a dictionary with dot-separated keys like:
    {
        'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
        'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
        'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
    }

    Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
    {
        'down_blocks': [0.6],
        'up_blocks': [0.3]
    }
    """
    # First group by values - convert lists to tuples to make them hashable
    value_to_keys = {}
    for key, value in d.items():
        value_tuple = tuple(value) if isinstance(value, list) else value
        if value_tuple not in value_to_keys:
            value_to_keys[value_tuple] = []
        value_to_keys[value_tuple].append(key)

    def find_common_prefix(keys: List[str]) -> str:
        """Find the shortest common prefix among a list of dot-separated keys."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        # Split all keys into parts
        key_parts = [k.split('.') for k in keys]

        # Find how many initial parts are common
        common_length = 0
        for parts in zip(*key_parts):
            if len(set(parts)) == 1:  # All parts at this position are the same
                common_length += 1
            else:
                break

        if common_length == 0:
            return ""

        # Return the common prefix
        return '.'.join(key_parts[0][:common_length])

    # Create summary by finding common prefixes for each value group
    summary = {}
    for value_tuple, keys in value_to_keys.items():
        prefix = find_common_prefix(keys)
        # Convert tuple back to list if it was originally a list
        value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
        if prefix:  # Only add if we found a common prefix
            summary[prefix] = value
        else:
            summary[""] = value  # Use empty string if no common prefix
    return summary
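
The file above is easier to follow with a concrete usage sketch. The snippet below is illustrative only and not part of the diff; it assumes the classes defined above are importable on this branch and uses small torch.nn.Linear modules as stand-in components.

import torch

manager = ComponentsManager()
manager.add("base_unet", torch.nn.Linear(4, 4))      # any nn.Module works as a stand-in
manager.add("refiner_unet", torch.nn.Linear(4, 4))

unet = manager.get("base_unet")          # exact-name lookup returns the component itself
unets = manager.get("*unet*")            # wildcard lookup returns a dict of matching components
not_refiner = manager.get("!refiner_*")  # negated prefix pattern

# With accelerate installed and a CUDA device available, models can be shuffled
# between CPU and the accelerator on demand:
# manager.enable_auto_cpu_offload(device="cuda", memory_reserve_margin="3GB")
print(manager)
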
@@ -912,12 +912,6 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -931,6 +925,11 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

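The hunks above (and the matching hunks for the other img2img pipelines below) are hard to read because the scraped diff lost its +/- markers. The sketch below reconstructs the intent under that assumption: the latents_mean/latents_std computation moves from unconditional code into the else: branch, so it only runs when the input still needs VAE encoding. Variable names follow the pipeline code; the function itself is hypothetical and omits the actual VAE encode.

import torch

def prepare_latents_after_refactor(image, vae_config, image_is_latents: bool):
    # Sketch only: the real method also encodes the image with the VAE and rescales it.
    if image_is_latents:
        init_latents = image
    else:
        latents_mean = latents_std = None
        if getattr(vae_config, "latents_mean", None) is not None:
            latents_mean = torch.tensor(vae_config.latents_mean).view(1, 4, 1, 1)
        if getattr(vae_config, "latents_std", None) is not None:
            latents_std = torch.tensor(vae_config.latents_std).view(1, 4, 1, 1)
        if getattr(vae_config, "force_upcast", False):
            image = image.float()  # the VAE overflows in float16
        init_latents = image
    return init_latents
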
@@ -867,12 +867,6 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -886,6 +880,11 @@ class StableDiffusionXLControlNetUnionImg2ImgPipeline(
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

@@ -609,12 +609,6 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -628,6 +622,11 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

src/diffusers/pipelines/modular_pipeline.py (new file, 1704 lines)
File diff suppressed because it is too large
@@ -917,12 +917,6 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -936,6 +930,11 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

@@ -707,12 +707,6 @@ class StableDiffusionXLPAGImg2ImgPipeline(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -726,6 +720,11 @@ class StableDiffusionXLPAGImg2ImgPipeline(
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

@@ -412,7 +412,7 @@ def _get_pipeline_class(
            revision=revision,
        )

    if class_obj.__name__ != "DiffusionPipeline":
    if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline":
        return class_obj

    diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0])

@@ -427,7 +427,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            module_is_sequentially_offloaded(module) for _, module in self.components.items()
        )

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -444,6 +444,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
            )


        # Display a warning in this case (the operation succeeds but the benefits are lost)
        pipeline_is_offloaded = any(module_is_offloaded(module) for _, module in self.components.items())
        if pipeline_is_offloaded and device_type in ["cuda", "xpu"]:
@@ -1119,9 +1120,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
                automatically detect the available accelerator and use.
        """

        self._maybe_raise_error_if_group_offload_active(raise_error=True)

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1

        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1245,7 +1248,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
        self.remove_all_hooks()

        is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
        is_pipeline_device_mapped = hasattr(self, "hf_device_map") and self.hf_device_map is not None and len(self.hf_device_map) > 1
        if is_pipeline_device_mapped:
            raise ValueError(
                "It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."

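The three hunks above all make the same change: hf_device_map is now read through hasattr, so pipeline-like objects that never set the attribute (for example the new modular pipelines) no longer raise AttributeError. A minimal, self-contained illustration of the guard (not diffusers code; the class name is made up):

class PipelineWithoutDeviceMap:
    pass

pipe = PipelineWithoutDeviceMap()

# The unguarded check would raise AttributeError on `pipe.hf_device_map`;
# the guarded version short-circuits to False.
is_pipeline_device_mapped = (
    hasattr(pipe, "hf_device_map") and pipe.hf_device_map is not None and len(pipe.hf_device_map) > 1
)
assert is_pipeline_device_mapped is False
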
@@ -29,6 +29,18 @@ else:
    _import_structure["pipeline_stable_diffusion_xl_img2img"] = ["StableDiffusionXLImg2ImgPipeline"]
    _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
    _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]
    _import_structure["pipeline_stable_diffusion_xl_modular"] = [
        "StableDiffusionXLControlNetDenoiseStep",
        "StableDiffusionXLDecodeLatentsStep",
        "StableDiffusionXLDenoiseStep",
        "StableDiffusionXLInputStep",
        "StableDiffusionXLModularPipeline",
        "StableDiffusionXLPrepareAdditionalConditioningStep",
        "StableDiffusionXLPrepareLatentsStep",
        "StableDiffusionXLSetTimestepsStep",
        "StableDiffusionXLTextEncoderStep",
        "StableDiffusionXLAutoPipeline",
    ]

if is_transformers_available() and is_flax_available():
    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
@@ -48,6 +60,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
    from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
    from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline
    from .pipeline_stable_diffusion_xl_modular import (
        StableDiffusionXLControlNetDenoiseStep,
        StableDiffusionXLDecodeLatentsStep,
        StableDiffusionXLDenoiseStep,
        StableDiffusionXLInputStep,
        StableDiffusionXLModularPipeline,
        StableDiffusionXLPrepareAdditionalConditioningStep,
        StableDiffusionXLPrepareLatentsStep,
        StableDiffusionXLSetTimestepsStep,
        StableDiffusionXLTextEncoderStep,
        StableDiffusionXLAutoPipeline,
    )

    try:
        if not (is_transformers_available() and is_flax_available()):

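If this branch is installed, the names registered above should be importable lazily from the SDXL subpackage. The exact public surface is an assumption based solely on this hunk:

# Assumes a diffusers build that contains this branch.
from diffusers.pipelines.stable_diffusion_xl import (
    StableDiffusionXLModularPipeline,
    StableDiffusionXLTextEncoderStep,
    StableDiffusionXLDenoiseStep,
    StableDiffusionXLDecodeLatentsStep,
)
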
@@ -695,12 +695,6 @@ class StableDiffusionXLImg2ImgPipeline(
                    f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
                )

            latents_mean = latents_std = None
            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)

            # Offload text encoder if `enable_model_cpu_offload` was enabled
            if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
                self.text_encoder_2.to("cpu")
@@ -714,6 +708,11 @@ class StableDiffusionXLImg2ImgPipeline(
                init_latents = image

            else:
                latents_mean = latents_std = None
                if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
                    latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
                if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
                    latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
                # make sure the VAE is in float32 mode, as it overflows in float16
                if self.vae.config.force_upcast:
                    image = image.float()

File diff suppressed because it is too large
@@ -1388,6 +1388,21 @@ class LDMSuperResolutionPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class ModularPipeline(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class PNDMPipeline(metaclass=DummyObject):
    _backends = ["torch"]

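These dummy classes exist so that importing diffusers keeps working when torch (or transformers) is missing: instantiating a placeholder raises an informative error instead of the package failing at import time. A self-contained mimic of the pattern, not the diffusers implementation:

class DummyObject(type):
    # Metaclass that turns instantiation into an informative error.
    def __call__(cls, *args, **kwargs):
        raise ImportError(f"{cls.__name__} requires the torch backend, which is not installed.")


class ModularPipelineStub(metaclass=DummyObject):
    _backends = ["torch"]


try:
    ModularPipelineStub()
except ImportError as err:
    print(err)  # ModularPipelineStub requires the torch backend, which is not installed.
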
@@ -2432,6 +2432,21 @@ class StableDiffusionXLInstructPix2PixPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLModularPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class StableDiffusionXLPAGImg2ImgPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

Block a user