Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-24 21:34:55 +08:00

Compare commits: 12 commits (`single-fil…` ... `convert-an…`)
| Author | SHA1 | Date |
|---|---|---|
| | bb3bdd3cc2 | |
| | 5a2561f51b | |
| | 4343ce2c8e | |
| | 0ca7b68198 | |
| | 3cf4f9c735 | |
| | 40dd9cb2bd | |
| | 30bcda7de6 | |
| | 9ea62d119a | |
| | a326d61118 | |
| | e7696e20f9 | |
| | 4b89aeffe1 | |
| | 0a1daadef8 | |
@@ -165,6 +165,25 @@ list_adapters_component_wise
 {"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
 ```
 
+## Compatibility with `torch.compile`
+
+If you want to compile your model with `torch.compile`, make sure to first fuse the LoRA weights into the base model and unload them.
+
+```py
+pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+
+pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
+# Fuses the LoRAs into the UNet
+pipe.fuse_lora()
+pipe.unload_lora_weights()
+
+pipe = torch.compile(pipe)
+
+prompt = "toy_face of a hacker with a hoodie, pixel art"
+image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
+```
+
 ## Fusing adapters into the model
 
 You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
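As a quick illustration of the fusing workflow this section introduces, here is a minimal sketch (assuming the two adapters are loaded and activated as above):

```py
# Fuse the active LoRA layers into the base weights, run inference, then undo the fusion.
pipe.fuse_lora()
image = pipe("toy_face of a hacker with a hoodie, pixel art", num_inference_steps=30).images[0]
pipe.unfuse_lora()  # restores the original base weights so adapters can be swapped again
```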
@@ -56,6 +56,60 @@ pipeline = DiffusionPipeline.from_pretrained(
 )
 ```
+
+### Load from a local file
+
+Community pipelines can also be loaded from a local directory if you pass its path instead. The directory must contain a `pipeline.py` file that defines the pipeline class in order to load successfully.
+
+```py
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    custom_pipeline="./path/to/pipeline_directory/",
+    clip_model=clip_model,
+    feature_extractor=feature_extractor,
+    use_safetensors=True,
+)
+```
+
+### Load from a specific version
+
+By default, community pipelines are loaded from the latest stable version of Diffusers. To load a community pipeline from another version, use the `custom_revision` parameter.
+
+<hfoptions id="version">
+<hfoption id="main">
+
+For example, to load from the `main` branch:
+
+```py
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    custom_pipeline="clip_guided_stable_diffusion",
+    custom_revision="main",
+    clip_model=clip_model,
+    feature_extractor=feature_extractor,
+    use_safetensors=True,
+)
+```
+
+</hfoption>
+<hfoption id="older version">
+
+For example, to load from a previous version of Diffusers like `v0.25.0`:
+
+```py
+pipeline = DiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    custom_pipeline="clip_guided_stable_diffusion",
+    custom_revision="v0.25.0",
+    clip_model=clip_model,
+    feature_extractor=feature_extractor,
+    use_safetensors=True,
+)
+```
+
+</hfoption>
+</hfoptions>
+
 
 For more information about community pipelines, take a look at the [Community pipelines](custom_pipeline_examples) guide for how to use them, and if you're interested in adding a community pipeline, check out the [How to contribute a community pipeline](contribute_pipeline) guide!
 
 ## Community components
@@ -376,18 +376,14 @@ After training, LoRA weights can be loaded very easily into the original pipelin
 load the original pipeline:
 
 ```python
-from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
-import torch
-
-pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-pipe.to("cuda")
+from diffusers import DiffusionPipeline
+pipe = DiffusionPipeline.from_pretrained("base-model-name").to("cuda")
 ```
 
-Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
+Next, we can load the adapter layers into the pipeline with the [`load_lora_weights` function](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters#lora).
 
 ```python
-pipe.unet.load_attn_procs("patrickvonplaten/lora_dreambooth_dog_example")
+pipe.load_lora_weights("path-to-the-lora-checkpoint")
 ```
 
 Finally, we can run the model in inference.
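To complete the updated snippet, a minimal inference sketch (the prompt follows the usual DreamBooth dog example; exact arguments may vary):

```python
image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```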
@@ -67,8 +67,8 @@ DATASET_NAME_MAPPING = {
 def save_model_card(
     args,
     repo_id: str,
-    images=None,
-    repo_folder=None,
+    images: list = None,
+    repo_folder: str = None,
 ):
     img_str = ""
     if len(images) > 0:
@@ -56,7 +56,9 @@ check_min_version("0.27.0.dev0")
 logger = get_logger(__name__, log_level="INFO")
 
 
-def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
+def save_model_card(
+    repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None
+):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
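These signature cleanups fix a subtle Python pitfall: `base_model=str` makes the default value the built-in `str` type object rather than a string. A standalone illustration (hypothetical function names, not from the diff):

```python
def bad(base_model=str):
    return base_model  # with no argument, this returns the built-in `str` type

def good(base_model: str = None):
    return base_model  # with no argument, this returns None as intended

assert bad() is str
assert good() is None
```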
@@ -66,12 +66,12 @@ DATASET_NAME_MAPPING = {
 
 def save_model_card(
     repo_id: str,
-    images=None,
-    validation_prompt=None,
-    base_model=str,
-    dataset_name=str,
-    repo_folder=None,
-    vae_path=None,
+    images: list = None,
+    validation_prompt: str = None,
+    base_model: str = None,
+    dataset_name: str = None,
+    repo_folder: str = None,
+    vae_path: str = None,
 ):
     img_str = ""
     for i, image in enumerate(images):
scripts/convert_animatelcm_lora_to_diffusers.py (new file, 62 lines)
@@ -0,0 +1,62 @@
+import argparse
+import re
+
+import torch
+from safetensors.torch import load_file, save_file
+
+
+def convert_lora(original_state_dict):
+    converted_state_dict = {}
+    for k, v in original_state_dict.items():
+        if "pos_encoder" in k:
+            continue
+        if "alpha" in k:
+            continue
+        else:
+            diffusers_key = (
+                k.replace(".norms.0", ".norm1")
+                .replace(".norms.1", ".norm2")
+                .replace(".ff_norm", ".norm3")
+                .replace(".attention_blocks.0", ".attn1")
+                .replace(".attention_blocks.1", ".attn2")
+                .replace(".temporal_transformer", "")
+                .replace("lora_unet_", "")
+            )
+            diffusers_key = diffusers_key.replace("to_out_0_", "to_out_")
+            diffusers_key = diffusers_key.replace("mid_block_", "mid_block.")
+            diffusers_key = diffusers_key.replace("attn1_", "attn1.processor.")
+            diffusers_key = diffusers_key.replace("attn2_", "attn2.processor.")
+            diffusers_key = diffusers_key.replace(".lora_", "_lora.")
+            diffusers_key = re.sub(r'_(\d+)_', r'.\1.', diffusers_key)
+
+            converted_state_dict[diffusers_key] = v
+
+    return converted_state_dict
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ckpt_path", type=str, required=True)
+    parser.add_argument("--output_path", type=str, required=True)
+
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    state_dict = load_file(args.ckpt_path)
+
+    if "state_dict" in state_dict.keys():
+        state_dict = state_dict["state_dict"]
+
+    converted_state_dict = convert_lora(state_dict)
+
+    # convert to new format
+    output_dict = {}
+    for module_name, params in converted_state_dict.items():
+        if type(params) is not torch.Tensor:
+            continue
+        output_dict.update({f"unet.{module_name}": params})
+
+    save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
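For illustration, the new converter could also be driven from Python rather than its CLI; a minimal sketch (the input and output paths are hypothetical placeholders):

```python
import torch
from safetensors.torch import load_file, save_file

state_dict = load_file("AnimateLCM_sd15_t2v_lora.safetensors")  # hypothetical path
state_dict = state_dict.get("state_dict", state_dict)

converted = convert_lora(state_dict)  # function defined in the new script above
output_dict = {f"unet.{k}": v for k, v in converted.items() if isinstance(v, torch.Tensor)}
save_file(output_dict, "converted/diffusion_pytorch_model.safetensors")  # hypothetical path
```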
@@ -38,6 +38,9 @@ class FromOriginalVAEMixin:
                 - A link to the `.ckpt` file (for example
                   `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                 - A path to a *file* containing all pipeline weights.
+            config_file (`str`, *optional*):
+                Filepath to the configuration YAML file associated with the model. If not provided it will default to:
+                https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                 dtype is automatically derived from the model's weights.
@@ -65,6 +68,13 @@ class FromOriginalVAEMixin:
             image_size (`int`, *optional*, defaults to 512):
                 The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
                 Diffusion v2 base model. Use 768 for Stable Diffusion v2.
+            scaling_factor (`float`, *optional*, defaults to 0.18215):
+                The component-wise standard deviation of the trained latent space computed using the first batch of the
+                training set. This is used to scale the latent space to have unit variance when training the diffusion
+                model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+                diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
+                = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
+                Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
             use_safetensors (`bool`, *optional*, defaults to `None`):
                 If set to `None`, the safetensors weights are downloaded if they're available **and** if the
                 safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -92,6 +102,7 @@ class FromOriginalVAEMixin:
         """
 
         original_config_file = kwargs.pop("original_config_file", None)
+        config_file = kwargs.pop("config_file", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -103,6 +114,13 @@ class FromOriginalVAEMixin:
         use_safetensors = kwargs.pop("use_safetensors", True)
 
         class_name = cls.__name__
+
+        if (config_file is not None) and (original_config_file is not None):
+            raise ValueError(
+                "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
+            )
+
+        original_config_file = original_config_file or config_file
         original_config, checkpoint = fetch_ldm_config_and_checkpoint(
             pretrained_model_link_or_path=pretrained_model_link_or_path,
             class_name=class_name,
@@ -118,7 +136,10 @@ class FromOriginalVAEMixin:
         )
 
         image_size = kwargs.pop("image_size", None)
-        component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
+        scaling_factor = kwargs.pop("scaling_factor", None)
+        component = create_diffusers_vae_model_from_ldm(
+            class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
+        )
         vae = component["vae"]
         if torch_dtype is not None:
             vae = vae.to(torch_dtype)
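Putting the new kwargs together, a minimal usage sketch (the checkpoint path is a placeholder; the argument names follow the diff above):

```python
from diffusers import AutoencoderKL

# Load a VAE from a single original-format checkpoint, overriding the latent
# scaling factor instead of relying on the config value or the 0.18215 default.
vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/<vae_checkpoint>.safetensors",
    scaling_factor=0.18215,
)
```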
@@ -38,6 +38,9 @@ class FromOriginalControlNetMixin:
                 - A link to the `.ckpt` file (for example
                   `"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
                 - A path to a *file* containing all pipeline weights.
+            config_file (`str`, *optional*):
+                Filepath to the configuration YAML file associated with the model. If not provided it will default to:
+                https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
             torch_dtype (`str` or `torch.dtype`, *optional*):
                 Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
                 dtype is automatically derived from the model's weights.
@@ -89,6 +92,7 @@ class FromOriginalControlNetMixin:
         ```
         """
         original_config_file = kwargs.pop("original_config_file", None)
+        config_file = kwargs.pop("config_file", None)
         resume_download = kwargs.pop("resume_download", False)
         force_download = kwargs.pop("force_download", False)
         proxies = kwargs.pop("proxies", None)
@@ -100,6 +104,12 @@ class FromOriginalControlNetMixin:
         use_safetensors = kwargs.pop("use_safetensors", True)
 
         class_name = cls.__name__
+        if (config_file is not None) and (original_config_file is not None):
+            raise ValueError(
+                "You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
+            )
+
+        original_config_file = config_file or original_config_file
         original_config, checkpoint = fetch_ldm_config_and_checkpoint(
             pretrained_model_link_or_path=pretrained_model_link_or_path,
             class_name=class_name,
@@ -48,6 +48,7 @@ def build_sub_model_components(
     load_safety_checker=False,
     model_type=None,
     image_size=None,
+    torch_dtype=None,
     **kwargs,
 ):
     if component_name in pipeline_components:
@@ -96,7 +97,7 @@ def build_sub_model_components(
             from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
             safety_checker = StableDiffusionSafetyChecker.from_pretrained(
-                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
+                "CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
             )
         else:
             safety_checker = None
@@ -48,7 +48,6 @@ if is_transformers_available():
 
 if is_accelerate_available():
     from accelerate import init_empty_weights
-    from accelerate.utils import set_module_tensor_to_device
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -175,6 +174,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
 }
 
 LDM_VAE_KEY = "first_stage_model."
+LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
 LDM_UNET_KEY = "model.diffusion_model."
 LDM_CONTROLNET_KEY = "control_model."
 LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
@@ -518,7 +518,10 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
     Creates a config for the diffusers based on the config of the LDM model.
     """
     vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
-    scaling_factor = scaling_factor or original_config["model"]["params"]["scale_factor"]
+    if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
+        scaling_factor = original_config["model"]["params"]["scale_factor"]
+    elif scaling_factor is None:
+        scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
 
     block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
     down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
@@ -870,8 +873,17 @@ def create_diffusers_controlnet_model_from_ldm(
     controlnet = ControlNetModel(**diffusers_config)
 
     if is_accelerate_available():
-        for param_name, param in diffusers_format_controlnet_checkpoint.items():
-            set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
+        from ..models.modeling_utils import load_model_dict_into_meta
+
+        unexpected_keys = load_model_dict_into_meta(controlnet, diffusers_format_controlnet_checkpoint)
+        if controlnet._keys_to_ignore_on_load_unexpected is not None:
+            for pat in controlnet._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warn(
+                f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
     else:
         controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
 
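A note on the loading pattern above (my gloss, with assumed semantics): `load_model_dict_into_meta` materializes checkpoint tensors into a model whose parameters were created on the meta device and returns the checkpoint keys that matched no parameter, which the code then filters against `_keys_to_ignore_on_load_unexpected`. A rough sketch of the idea, with a hypothetical model class:

```python
from accelerate import init_empty_weights

with init_empty_weights():
    model = MyModel(**config)  # hypothetical; parameters are allocated on the meta device

# Assumed behavior: copies matching tensors in, returns keys with no matching parameter.
unexpected_keys = load_model_dict_into_meta(model, checkpoint_state_dict)
```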
@@ -1034,8 +1046,17 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
         text_model_dict[diffusers_key] = checkpoint[key]
 
     if is_accelerate_available():
-        for param_name, param in text_model_dict.items():
-            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+        from ..models.modeling_utils import load_model_dict_into_meta
+
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        if text_model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in text_model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warn(
+                f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
     else:
         if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
             text_model_dict.pop("text_model.embeddings.position_ids", None)
@@ -1116,8 +1137,17 @@ def create_text_encoder_from_open_clip_checkpoint(
         text_model_dict[diffusers_key] = checkpoint[key]
 
     if is_accelerate_available():
-        for param_name, param in text_model_dict.items():
-            set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
+        from ..models.modeling_utils import load_model_dict_into_meta
+
+        unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
+        if text_model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in text_model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warn(
+                f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
+
     else:
         if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings, "position_ids")):
@@ -1164,8 +1194,17 @@ def create_diffusers_unet_model_from_ldm(
     unet = UNet2DConditionModel(**unet_config)
 
     if is_accelerate_available():
-        for param_name, param in diffusers_format_unet_checkpoint.items():
-            set_module_tensor_to_device(unet, param_name, "cpu", value=param)
+        from ..models.modeling_utils import load_model_dict_into_meta
+
+        unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint)
+        if unet._keys_to_ignore_on_load_unexpected is not None:
+            for pat in unet._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warn(
+                f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
     else:
         unet.load_state_dict(diffusers_format_unet_checkpoint)
 
@@ -1173,7 +1212,7 @@ def create_diffusers_unet_model_from_ldm(
 
 
 def create_diffusers_vae_model_from_ldm(
-    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125
+    pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
 ):
     # import here to avoid circular imports
     from ..models import AutoencoderKL
@@ -1188,8 +1227,17 @@ def create_diffusers_vae_model_from_ldm(
     vae = AutoencoderKL(**vae_config)
 
     if is_accelerate_available():
-        for param_name, param in diffusers_format_vae_checkpoint.items():
-            set_module_tensor_to_device(vae, param_name, "cpu", value=param)
+        from ..models.modeling_utils import load_model_dict_into_meta
+
+        unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint)
+        if vae._keys_to_ignore_on_load_unexpected is not None:
+            for pat in vae._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warn(
+                f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
     else:
         vae.load_state_dict(diffusers_format_vae_checkpoint)
 
@@ -1226,7 +1274,9 @@ def create_text_encoders_and_tokenizers_from_ldm(
         try:
             config_name = "openai/clip-vit-large-patch14"
             text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
-                config_name, checkpoint, local_files_only=local_files_only
+                config_name,
+                checkpoint,
+                local_files_only=local_files_only,
             )
             tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
 
@@ -981,10 +981,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
-            custom_revision (`str`, *optional*, defaults to `"main"`):
+            custom_revision (`str`, *optional*):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
-                `revision` when loading a custom pipeline from the Hub. It can be a 🤗 Diffusers version when loading a
-                custom pipeline from GitHub, otherwise it defaults to `"main"` when loading from the Hub.
+                `revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers version.
             mirror (`str`, *optional*):
                 Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
@@ -151,7 +151,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
         sample_max_value: float = 1.0,
         algorithm_type: str = "dpmsolver++",
         solver_type: str = "midpoint",
-        lower_order_final: bool = True,
+        lower_order_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -233,7 +233,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
             orders = [1, 2, 3] * (steps // 3) + [1, 2]
         elif order == 2:
             if steps % 2 == 0:
-                orders = [1, 2] * (steps // 2)
+                orders = [1, 2] * (steps // 2 - 1) + [1, 1]
             else:
                 orders = [1, 2] * (steps // 2) + [1]
         elif order == 1:
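A quick illustration of what the even-step change does (my worked example, not part of the diff):

```python
steps = 6
old_orders = [1, 2] * (steps // 2)               # [1, 2, 1, 2, 1, 2]: ends on a second-order step
new_orders = [1, 2] * (steps // 2 - 1) + [1, 1]  # [1, 2, 1, 2, 1, 1]: ends on first-order steps
```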
@@ -320,7 +320,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
 
         if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
             logger.warn(
-                "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
+                "Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
             )
             self.register_to_config(lower_order_final=True)
 
@@ -37,8 +37,10 @@ from diffusers import (
     EulerDiscreteScheduler,
     LCMScheduler,
     StableDiffusionPipeline,
+    StableDiffusionXLAdapterPipeline,
     StableDiffusionXLControlNetPipeline,
     StableDiffusionXLPipeline,
+    T2IAdapter,
     UNet2DConditionModel,
 )
 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
@@ -2175,7 +2177,7 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
         self.assertTrue(np.allclose(images, expected, atol=1e-3))
         release_memory(pipeline)
 
-    def test_canny_lora(self):
+    def test_controlnet_canny_lora(self):
         controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
 
         pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
@@ -2199,6 +2201,34 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
         assert np.allclose(original_image, expected_image, atol=1e-04)
         release_memory(pipe)
 
+    def test_sdxl_t2i_adapter_canny_lora(self):
+        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
+            "cpu"
+        )
+        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+            adapter=adapter,
+            torch_dtype=torch.float16,
+            variant="fp16",
+        )
+        pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
+        pipe.enable_model_cpu_offload()
+        pipe.set_progress_bar_config(disable=None)
+
+        generator = torch.Generator(device="cpu").manual_seed(0)
+        prompt = "toy"
+        image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
+        )
+
+        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
+
+        assert images[0].shape == (768, 512, 3)
+
+        image_slice = images[0, -3:, -3:, -1].flatten()
+        expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
+        assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
+
     @nightly
     def test_sequential_fuse_unfuse(self):
         pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
+    is_flaky,
     numpy_cosine_similarity_distance,
     require_torch_gpu,
     slow,
@@ -259,6 +260,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         ]
         assert processors == [True] * len(processors)
 
+    @is_flaky
     def test_multi(self):
         image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
         pipeline = StableDiffusionPipeline.from_pretrained(
@@ -275,7 +277,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
         inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
         images = pipeline(**inputs).images
         image_slice = images[0, :3, :3, -1].flatten()
-        expected_slice = np.array([0.1704, 0.1296, 0.1272, 0.2212, 0.1514, 0.1479, 0.4172, 0.4263, 0.4360])
+        expected_slice = np.array([0.5234, 0.5352, 0.5625, 0.5713, 0.5947, 0.6206, 0.5786, 0.6187, 0.6494])
 
         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
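For reference, the `numpy_cosine_similarity_distance` helper used in these assertions is, to the best of my knowledge, one minus the cosine similarity of the flattened arrays; a minimal re-implementation sketch:

```python
import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 0.0 means the two flattened slices point in exactly the same direction.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return 1.0 - float(similarity)
```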
@@ -672,34 +672,6 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()
 
-    def test_canny_lora(self):
-        adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
-            "cpu"
-        )
-        pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0",
-            adapter=adapter,
-            torch_dtype=torch.float16,
-            variant="fp16",
-        )
-        pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
-        pipe.enable_model_cpu_offload()
-        pipe.set_progress_bar_config(disable=None)
-
-        generator = torch.Generator(device="cpu").manual_seed(0)
-        prompt = "toy"
-        image = load_image(
-            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
-        )
-
-        images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
-
-        assert images[0].shape == (768, 512, 3)
-
-        image_slice = images[0, -3:, -3:, -1].flatten()
-        expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
-        assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
-
     def test_download_ckpt_diff_format_is_same(self):
         ckpt_path = (
             "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"