Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: v0.33.0 ... fix-timeou (14 commits)
| SHA1 |
|---|
| 17d06fe3f1 |
| 511d738121 |
| d1a3979b3b |
| ea5a6a8b7c |
| b8093e6665 |
| e121d0ef67 |
| 0365bd1c17 |
| 31c4f24fc1 |
| 0efdf411fb |
| 450dc48a2c |
| 77b4f66b9e |
| 68663f8a17 |
| ffda8735be |
| 0706786e53 |
@@ -53,7 +53,12 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     }

     for i in range(num_down_blocks):
-        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+        resnets = [
+            key
+            for key in down_blocks[i]
+            if f"down.{i}" in key and f"down.{i}.downsample" not in key and "attn" not in key
+        ]
+        attentions = [key for key in down_blocks[i] if f"down.{i}.attn" in key]

         if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
@@ -67,6 +72,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
         meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"down.{i}.attn", "new": f"down_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
     mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):
@@ -85,8 +94,11 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
     for i in range(num_up_blocks):
         block_id = num_up_blocks - 1 - i
         resnets = [
-            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+            key
+            for key in up_blocks[block_id]
+            if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key and "attn" not in key
         ]
+        attentions = [key for key in up_blocks[block_id] if f"up.{block_id}.attn" in key]

         if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
             new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
@@ -100,6 +112,10 @@ def custom_convert_ldm_vae_checkpoint(checkpoint, config):
         meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
         assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

+        paths = renew_vae_attention_paths(attentions)
+        meta_path = {"old": f"up.{block_id}.attn", "new": f"up_blocks.{i}.attentions"}
+        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
     mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
     num_mid_res_blocks = 2
     for i in range(1, num_mid_res_blocks + 1):
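In both loops, the tightened `resnets` filter and the new `attentions` list split a block's keys by role before path renaming. A small illustration of the filtering on hypothetical LDM VAE keys:

```python
# Hypothetical keys for encoder down block 0 of an LDM VAE that uses attention.
keys = [
    "encoder.down.0.block.0.conv1.weight",
    "encoder.down.0.attn.0.q.weight",
    "encoder.down.0.downsample.conv.weight",
]
i = 0
resnets = [k for k in keys if f"down.{i}" in k and f"down.{i}.downsample" not in k and "attn" not in k]
attentions = [k for k in keys if f"down.{i}.attn" in k]
print(resnets)     # ['encoder.down.0.block.0.conv1.weight']
print(attentions)  # ['encoder.down.0.attn.0.q.weight']
```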
@@ -1608,3 +1608,64 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
         converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)

     return converted_state_dict
+
+
+def _convert_musubi_wan_lora_to_diffusers(state_dict):
+    # https://github.com/kohya-ss/musubi-tuner
+    converted_state_dict = {}
+    original_state_dict = {k[len("lora_unet_") :]: v for k, v in state_dict.items()}
+
+    num_blocks = len({k.split("blocks_")[1].split("_")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
+
+    def get_alpha_scales(down_weight, key):
+        rank = down_weight.shape[0]
+        alpha = original_state_dict.pop(key + ".alpha").item()
+        scale = alpha / rank  # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+        return scale_down, scale_up
+
+    for i in range(num_blocks):
+        # Self-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_self_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_self_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn1.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # Cross-attention
+        for o, c in zip(["q", "k", "v", "o"], ["to_q", "to_k", "to_v", "to_out.0"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                down_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_down.weight")
+                up_weight = original_state_dict.pop(f"blocks_{i}_cross_attn_{o}.lora_up.weight")
+                scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_cross_attn_{o}")
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = down_weight * scale_down
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = up_weight * scale_up
+
+        # FFN
+        for o, c in zip(["ffn_0", "ffn_2"], ["net.0.proj", "net.2"]):
+            down_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_down.weight")
+            up_weight = original_state_dict.pop(f"blocks_{i}_{o}.lora_up.weight")
+            scale_down, scale_up = get_alpha_scales(down_weight, f"blocks_{i}_{o}")
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_A.weight"] = down_weight * scale_down
+            converted_state_dict[f"blocks.{i}.ffn.{c}.lora_B.weight"] = up_weight * scale_up
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`state_dict` should be empty at this point but has {original_state_dict.keys()=}")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
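The `get_alpha_scales` helper folds the trainer's `alpha / rank` LoRA scaling into the converted weights, splitting the factor between the down (A) and up (B) matrices in powers of two so neither side drifts far from unit scale. A standalone sketch of the invariant it preserves, with hypothetical shapes:

```python
import torch

def split_scale(scale):
    # Same power-of-two split as get_alpha_scales: scale_down * scale_up == scale.
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up

rank, alpha = 16, 8.0          # hypothetical LoRA rank and alpha
down = torch.randn(rank, 64)   # lora_down -> lora_A
up = torch.randn(128, rank)    # lora_up -> lora_B

scale_down, scale_up = split_scale(alpha / rank)

# The effective LoRA delta is unchanged by the split.
delta_ref = (up @ down) * (alpha / rank)
delta_split = (up * scale_up) @ (down * scale_down)
assert torch.allclose(delta_ref, delta_split, atol=1e-6)
```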
@@ -42,6 +42,7 @@ from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
+    _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
@@ -4794,6 +4795,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         )
         if any(k.startswith("diffusion_model.") for k in state_dict):
             state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)

         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
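With this branch in place, Musubi-tuner checkpoints, whose keys carry the `lora_unet_` prefix, load through the standard LoRA entry point. A minimal sketch; the model ID and LoRA path are illustrative placeholders:

```python
import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# Keys in the file look like "lora_unet_blocks_0_self_attn_q.lora_down.weight",
# which routes the state dict through _convert_musubi_wan_lora_to_diffusers.
pipe.load_lora_weights("path/to/musubi_wan_lora.safetensors")
```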
@@ -177,6 +177,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
     "ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
     "ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
+    "ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
     "autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
     "autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
     "autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
@@ -638,7 +639,9 @@ def infer_diffusers_model_type(checkpoint):
         model_type = "flux-schnell"

     elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
-        if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
+        if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+            model_type = "ltx-video-0.9.5"
+        elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
             model_type = "ltx-video-0.9.1"
         else:
             model_type = "ltx-video"
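The 0.9.5 VAE is recognized purely by tensor shape: its encoder output convolution has 2048 input channels. A condensed rendition of the new branch, exercised on a hypothetical minimal checkpoint:

```python
import torch

def ltx_variant(checkpoint):
    # Mirrors the updated branch in infer_diffusers_model_type (simplified).
    if checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
        return "ltx-video-0.9.5"
    elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
        return "ltx-video-0.9.1"
    return "ltx-video"

# shape[1] of a conv kernel is its input-channel count; 2048 marks the 0.9.5 VAE.
ckpt = {"vae.encoder.conv_out.conv.weight": torch.zeros(32, 2048, 3, 3)}
print(ltx_variant(ckpt))  # ltx-video-0.9.5
```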
@@ -2403,13 +2406,41 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
         "last_scale_shift_table": "scale_shift_table",
     }

+    VAE_095_RENAME_DICT = {
+        # decoder
+        "up_blocks.0": "mid_block",
+        "up_blocks.1": "up_blocks.0.upsamplers.0",
+        "up_blocks.2": "up_blocks.0",
+        "up_blocks.3": "up_blocks.1.upsamplers.0",
+        "up_blocks.4": "up_blocks.1",
+        "up_blocks.5": "up_blocks.2.upsamplers.0",
+        "up_blocks.6": "up_blocks.2",
+        "up_blocks.7": "up_blocks.3.upsamplers.0",
+        "up_blocks.8": "up_blocks.3",
+        # encoder
+        "down_blocks.0": "down_blocks.0",
+        "down_blocks.1": "down_blocks.0.downsamplers.0",
+        "down_blocks.2": "down_blocks.1",
+        "down_blocks.3": "down_blocks.1.downsamplers.0",
+        "down_blocks.4": "down_blocks.2",
+        "down_blocks.5": "down_blocks.2.downsamplers.0",
+        "down_blocks.6": "down_blocks.3",
+        "down_blocks.7": "down_blocks.3.downsamplers.0",
+        "down_blocks.8": "mid_block",
+        # common
+        "last_time_embedder": "time_embedder",
+        "last_scale_shift_table": "scale_shift_table",
+    }
+
     VAE_SPECIAL_KEYS_REMAP = {
         "per_channel_statistics.channel": remove_keys_,
         "per_channel_statistics.mean-of-means": remove_keys_,
         "per_channel_statistics.mean-of-stds": remove_keys_,
     }

-    if "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
+    if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
+        VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
+    elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
         VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)

     for key in list(converted_state_dict.keys()):
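The converter applies these maps by plain substring replacement in dict insertion order over every key, so the entries are ordered to avoid re-matching already-renamed prefixes. A simplified rendition of the rename loop, using an excerpt of the map above:

```python
VAE_095_RENAME_DICT_EXCERPT = {
    "up_blocks.0": "mid_block",
    "up_blocks.3": "up_blocks.1.upsamplers.0",
}

def rename(key, rename_dict):
    # Apply each mapping in insertion order, as the LTX converter's rename loop does.
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key

print(rename("decoder.up_blocks.3.conv.weight", VAE_095_RENAME_DICT_EXCERPT))
# decoder.up_blocks.1.upsamplers.0.conv.weight
```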
@@ -350,8 +350,14 @@ def create_vae_diffusers_config(original_config, image_size: int):
     _ = original_config["model"]["params"]["first_stage_config"]["params"]["embed_dim"]

     block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
-    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
-    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+    down_block_types = [
+        "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ]
+    up_block_types = [
+        "UpDecoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnUpDecoderBlock2D"
+        for i, _ in enumerate(block_out_channels)
+    ][::-1]

     config = {
         "sample_size": image_size,
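The comprehension marks a level as attention-bearing when its feature-map resolution (`image_size // 2**i`) appears in the LDM config's `attn_resolutions`. A worked example with hypothetical VAE parameters:

```python
# Hypothetical LDM VAE config: a 256px autoencoder with attention at 32px resolution.
image_size = 256
vae_params = {"ch": 128, "ch_mult": (1, 2, 4, 4), "attn_resolutions": [32]}

block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = [
    "DownEncoderBlock2D" if image_size // 2**i not in vae_params["attn_resolutions"] else "AttnDownEncoderBlock2D"
    for i, _ in enumerate(block_out_channels)
]
print(down_block_types)
# ['DownEncoderBlock2D', 'DownEncoderBlock2D', 'DownEncoderBlock2D', 'AttnDownEncoderBlock2D']
```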
@@ -15,7 +15,6 @@
 import html
 from typing import Any, Callable, Dict, List, Optional, Union

-import ftfy
 import regex as re
 import torch
 from transformers import AutoTokenizer, UMT5EncoderModel
@@ -24,7 +23,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
@@ -40,6 +39,9 @@ else:

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+if is_ftfy_available():
+    import ftfy
+

 EXAMPLE_DOC_STRING = """
 Examples:
@@ -15,7 +15,6 @@
 import html
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

-import ftfy
 import PIL
 import regex as re
 import torch
@@ -26,7 +25,7 @@ from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
@@ -42,6 +41,9 @@ else:

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+if is_ftfy_available():
+    import ftfy
+
 EXAMPLE_DOC_STRING = """
 Examples:
     ```python
@@ -16,7 +16,6 @@ import html
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Union

-import ftfy
 import regex as re
 import torch
 from PIL import Image
@@ -26,7 +25,7 @@ from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ...video_processor import VideoProcessor
 from ..pipeline_utils import DiffusionPipeline
@@ -42,6 +41,9 @@ else:

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+if is_ftfy_available():
+    import ftfy
+

 EXAMPLE_DOC_STRING = """
 Examples:
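All three Wan pipelines apply the same change: `ftfy` moves from a hard top-level import to an optional one guarded by `is_ftfy_available()`, so importing the pipeline no longer fails when `ftfy` is not installed. A condensed sketch of the pattern; the availability check here is a simplified stand-in for `diffusers.utils.is_ftfy_available`:

```python
import importlib.util

def is_ftfy_available() -> bool:
    # Simplified stand-in: the dependency is optional, so only probe for it.
    return importlib.util.find_spec("ftfy") is not None

if is_ftfy_available():
    import ftfy

def basic_clean(text: str) -> str:
    # Only call into ftfy when the optional dependency is present.
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    return text.strip()
```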
@@ -101,18 +101,20 @@ _onnx_available = importlib.util.find_spec("onnxruntime") is not None
 if _onnx_available:
     candidates = (
         "onnxruntime",
+        "onnxruntime-cann",
+        "onnxruntime-directml",
+        "ort_nightly_directml",
         "onnxruntime-gpu",
         "ort_nightly_gpu",
-        "onnxruntime-directml",
-        "onnxruntime-openvino",
-        "ort_nightly_directml",
-        "onnxruntime-rocm",
         "onnxruntime-migraphx",
+        "onnxruntime-openvino",
+        "onnxruntime-qnn",
+        "onnxruntime-rocm",
         "onnxruntime-training",
         "onnxruntime-vitisai",
     )
     _onnxruntime_version = None
-    # For the metadata, we have to look for both onnxruntime and onnxruntime-gpu
+    # For the metadata, we have to look for both onnxruntime and onnxruntime-x
     for pkg in candidates:
         try:
             _onnxruntime_version = importlib_metadata.version(pkg)
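Every flavor of onnxruntime installs the same importable `onnxruntime` module, so the version has to be recovered from package metadata by probing each distribution name in turn. A simplified sketch of that probe loop:

```python
import importlib.metadata as importlib_metadata

def find_onnxruntime_version(candidates):
    # The first installed distribution wins; PackageNotFoundError means "try the next flavor".
    for pkg in candidates:
        try:
            return importlib_metadata.version(pkg)
        except importlib_metadata.PackageNotFoundError:
            continue
    return None

print(find_onnxruntime_version(("onnxruntime", "onnxruntime-gpu")))
```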
@@ -33,6 +33,7 @@ from diffusers import (
 )
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     load_image,
     nightly,
@@ -455,11 +456,54 @@ class LoraIntegrationTests(unittest.TestCase):

         images = pipe("A pokemon with blue eyes.", output_type="np", generator=generator, num_inference_steps=2).images

-        images = images[0, -3:, -3:, -1].flatten()
+        image_slice = images[0, -3:, -3:, -1].flatten()

-        expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array([0.6544, 0.6127, 0.5397, 0.6845, 0.6047, 0.5469, 0.6349, 0.5906, 0.5382]),
+                ("cuda", 7): np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583]),
+                ("cuda", 8): np.array([0.6542, 0.61253, 0.5396, 0.6843, 0.6044, 0.5468, 0.6349, 0.5905, 0.5381]),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()

-        max_diff = numpy_cosine_similarity_distance(expected, images)
+        max_diff = numpy_cosine_similarity_distance(expected_slice, image_slice)
         assert max_diff < 1e-4

         pipe.unload_lora_weights()
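`Expectations` keys reference values by device type and a major version (CUDA compute capability, XPU generation), and `get_expectation()` picks the entry matching the current runtime, which lets one test carry per-hardware reference slices. A rough stand-in for the lookup idea, not the diffusers implementation:

```python
import torch

def get_expectation(table):
    # Resolve a (device_type, major_version) key for the current runtime.
    if torch.cuda.is_available():
        key = ("cuda", torch.cuda.get_device_capability()[0])
    else:
        key = ("xpu", 3)  # illustrative default matching the XPU entries above
    return table[key]
```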
@@ -260,6 +260,31 @@ class PeftLoraLoaderMixinTests:

         return modules_to_save

+    def check_if_adapters_added_correctly(
+        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
+    ):
+        if text_lora_config is not None:
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                )
+
+        if denoiser_lora_config is not None:
+            denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+            denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+        else:
+            denoiser = None
+
+        if text_lora_config is not None and (self.has_two_text_encoders or self.has_three_text_encoders):
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+        return pipe, denoiser
+
     def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
@@ -289,16 +314,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            lora_loadable_components = self.pipeline_class._lora_loadable_modules
-            if "text_encoder_2" in lora_loadable_components:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -381,22 +397,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -459,16 +460,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            lora_loadable_components = self.pipeline_class._lora_loadable_modules
-            if "text_encoder_2" in lora_loadable_components:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -506,15 +498,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -546,19 +530,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            lora_loadable_components = self.pipeline_class._lora_loadable_modules
-            if "text_encoder_2" in lora_loadable_components:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -593,18 +565,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -655,22 +616,20 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-        # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
-        # supports missing layers (PR#8324).
-        state_dict = {
-            f"text_encoder.{module_name}": param
-            for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
-            if "text_model.encoder.layers.4" not in module_name
-        }
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+
+        state_dict = {}
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
+            # supports missing layers (PR#8324).
+            state_dict = {
+                f"text_encoder.{module_name}": param
+                for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
+                if "text_model.encoder.layers.4" not in module_name
+            }

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder_2.add_adapter(text_lora_config)
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )
                 state_dict.update(
                     {
                         f"text_encoder_2.{module_name}": param
@@ -694,7 +653,7 @@ class PeftLoraLoaderMixinTests:
             "Removing adapters should change the output",
         )

-    def test_simple_inference_save_pretrained(self):
+    def test_simple_inference_save_pretrained_with_text_lora(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
@@ -708,16 +667,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
-
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -726,10 +676,11 @@ class PeftLoraLoaderMixinTests:
             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)

-            self.assertTrue(
-                check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
-                "Lora not correctly set in text encoder",
-            )
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
+                    "Lora not correctly set in text encoder",
+                )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
@@ -759,22 +710,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -820,22 +756,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -879,22 +800,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -932,22 +838,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -983,22 +874,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1104,6 +980,8 @@ class PeftLoraLoaderMixinTests:
         )

     def test_wrong_adapter_name_raises_error(self):
+        adapter_name = "adapter-1"
+
         scheduler_cls = self.scheduler_classes[0]
         components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
         pipe = self.pipeline_class(**components)
@@ -1111,20 +989,9 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
+        )

         with self.assertRaises(ValueError) as err_context:
             pipe.set_adapters("test")
@@ -1132,10 +999,11 @@ class PeftLoraLoaderMixinTests:
         self.assertTrue("not in the list of present adapters" in str(err_context.exception))

         # test this works.
-        pipe.set_adapters("adapter-1")
+        pipe.set_adapters(adapter_name)
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

     def test_multiple_wrong_adapter_name_raises_error(self):
+        adapter_name = "adapter-1"
         scheduler_cls = self.scheduler_classes[0]
         components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
         pipe = self.pipeline_class(**components)
@@ -1143,33 +1011,22 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
+        )

         scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
         logger = logging.get_logger("diffusers.loaders.lora_base")
         logger.setLevel(30)
         with CaptureLogger(logger) as cap_logger:
-            pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
+            pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)

         wrong_components = sorted(set(scale_with_wrong_components.keys()))
         msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
         self.assertTrue(msg in str(cap_logger.out))

         # test this works.
-        pipe.set_adapters("adapter-1")
+        pipe.set_adapters(adapter_name)
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

     def test_simple_inference_with_text_denoiser_block_scale(self):
@@ -1804,20 +1661,7 @@ class PeftLoraLoaderMixinTests:
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            lora_loadable_components = self.pipeline_class._lora_loadable_modules
-            if "text_encoder_2" in lora_loadable_components:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1908,18 +1752,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe.text_encoder.add_adapter(text_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            pipe.text_encoder_2.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-            )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -2011,22 +1844,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2211,22 +2029,7 @@ class PeftLoraLoaderMixinTests:
         pipe = pipe.to(torch_device, dtype=compute_dtype)
         pipe.set_progress_bar_config(disable=None)

-        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
-            pipe.text_encoder.add_adapter(text_lora_config)
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
-
-        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
-        denoiser.add_adapter(denoiser_lora_config)
-        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-
-        if self.has_two_text_encoders or self.has_three_text_encoders:
-            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
-                pipe.text_encoder_2.add_adapter(text_lora_config)
-                self.assertTrue(
-                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
-                )
+        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

         if storage_dtype is not None:
             denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -187,7 +187,7 @@ class I2VGenXLPipelineFastTests(SDFunctionTesterMixin, PipelineTesterMixin, unit
         super().test_sequential_cpu_offload_forward_pass(expected_max_diff=0.008)

     def test_dict_tuple_outputs_equivalent(self):
-        super().test_dict_tuple_outputs_equivalent(expected_max_difference=0.008)
+        super().test_dict_tuple_outputs_equivalent(expected_max_difference=0.009)

     def test_save_load_optional_components(self):
         super().test_save_load_optional_components(expected_max_difference=0.008)
@@ -34,6 +34,7 @@ from diffusers import (
 from diffusers.image_processor import IPAdapterMaskProcessor
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     enable_full_determinism,
     is_flaky,
@@ -664,7 +665,50 @@ class IPAdapterSDXLIntegrationTests(IPAdapterNightlyTestsMixin):
         images = pipeline(**inputs).images
         image_slice = images[0, :3, :3, -1].flatten()

-        expected_slice = np.array([0.2323, 0.1026, 0.1338, 0.0638, 0.0662, 0.0000, 0.0000, 0.0000, 0.0199])
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array([0.2520, 0.1050, 0.1510, 0.0997, 0.0893, 0.0019, 0.0000, 0.0000, 0.0210]),
+                ("cuda", 7): np.array([0.2323, 0.1026, 0.1338, 0.0638, 0.0662, 0.0000, 0.0000, 0.0000, 0.0199]),
+                ("cuda", 8): np.array([0.2518, 0.1059, 0.1553, 0.0977, 0.0852, 0.0000, 0.0000, 0.0000, 0.0220]),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()

         max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
         assert max_diff < 5e-4
@@ -37,6 +37,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    Expectations,
     backend_empty_cache,
     backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
@@ -866,7 +867,37 @@ class StableDiffusionInpaintPipelineAsymmetricAutoencoderKLSlowTests(unittest.Te
         image_slice = image[0, 253:256, 253:256, -1].flatten()

         assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array([0.1343, 0.1406, 0.1440, 0.1504, 0.1729, 0.0989, 0.1807, 0.2822, 0.1179])
+        expected_slices = Expectations(
+            {
+                ("xpu", 3): np.array([0.2063, 0.1731, 0.1553, 0.1741, 0.1772, 0.1077, 0.2109, 0.2407, 0.1243]),
+                ("cuda", 7): np.array([0.1343, 0.1406, 0.1440, 0.1504, 0.1729, 0.0989, 0.1807, 0.2822, 0.1179]),
+            }
+        )
+        expected_slice = expected_slices.get_expectation()

         assert np.abs(expected_slice - image_slice).max() < 5e-2
@@ -1347,7 +1347,7 @@ class PipelineTesterMixin:

     @unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
     @require_accelerator
-    def test_float16_inference(self, expected_max_diff=6e-2):
+    def test_float16_inference(self, expected_max_diff=5e-2):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         for component in pipe.components.values():
@@ -381,7 +381,7 @@ class UnCLIPPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         ]

         self._test_inference_batch_single_identical(
-            additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=5e-3
+            additional_params_copy_to_batched_inputs=additional_params_copy_to_batched_inputs, expected_max_diff=9.8e-3
         )

     def test_inference_batch_consistent(self):
@@ -17,8 +17,6 @@ import os

 import requests

-from ..src.diffusers.utils.constants import DIFFUSERS_REQUEST_TIMEOUT
-

 # Configuration
 LIBRARY_NAME = "diffusers"
@@ -28,7 +26,7 @@ SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")

 def check_pypi_for_latest_release(library_name):
     """Check PyPI for the latest release of the library."""
-    response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=DIFFUSERS_REQUEST_TIMEOUT)
+    response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=60)
     if response.status_code == 200:
         data = response.json()
         return data["info"]["version"]
@@ -40,7 +38,7 @@ def check_pypi_for_latest_release(library_name):
 def get_github_release_info(github_repo):
     """Fetch the latest release info from GitHub."""
     url = f"https://api.github.com/repos/{github_repo}/releases/latest"
-    response = requests.get(url, timeout=DIFFUSERS_REQUEST_TIMEOUT)
+    response = requests.get(url, timeout=60)

     if response.status_code == 200:
         data = response.json()