Compare commits


9 Commits

Author SHA1 Message Date
Dhruv Nair
bb3bdd3cc2 update 2024-02-15 08:09:09 +00:00
Dhruv Nair
5a2561f51b update 2024-02-15 07:14:20 +00:00
Sayak Paul
4343ce2c8e [Core] Harmonize single file ckpt model loading (#6971)
* use load_model_into_meta in single file utils

* propagate to autoencoder and controlnet.

* correct class name access behaviour.

* remove torch_dtype from load_model_into_meta; seems unnecessary

* remove incorrect kwarg

* style to avoid extra unnecessary line breaks
2024-02-14 10:49:06 +05:30
Younes Belkada
0ca7b68198 [PEFT / docs] Add a note about torch.compile (#6864)
* Update using_peft_for_inference.md

* add more explanation
2024-02-14 02:29:29 +01:00
Dhruv Nair
3cf4f9c735 Allow passing config_file argument to ControlNetModel when using from_single_file (#6959)
* update

* update

* update
2024-02-13 18:54:53 +05:30
Dhruv Nair
40dd9cb2bd Move SDXL T2I Adapter lora test into PEFT workflow (#6965)
update
2024-02-13 17:08:53 +05:30
Dhruv Nair
30bcda7de6 Fix flaky IP Adapter test (#6960)
update
2024-02-13 17:07:39 +05:30
YiYi Xu
9ea62d119a [DPMSolverSinglestepScheduler] correct get_order_list for solver_order=2 and lower_order_final=True (#6953)
* add

* change default

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
2024-02-12 22:10:33 -10:00
Dhruv Nair
a326d61118 Fix configuring VAE from single file mixin (#6950)
* update
2024-02-12 22:10:05 -10:00
11 changed files with 231 additions and 68 deletions

View File

@@ -165,6 +165,25 @@ list_adapters_component_wise
{"text_encoder": ["toy", "pixel"], "unet": ["toy", "pixel"], "text_encoder_2": ["toy", "pixel"]}
```
## Compatibility with `torch.compile`
If you want to compile your model with `torch.compile`, make sure to first fuse the LoRA weights into the base model and then unload them.
```py
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])
# Fuses the LoRAs into the Unet
pipe.fuse_lora()
pipe.unload_lora_weights()
pipe = torch.compile(pipe)
prompt = "toy_face of a hacker with a hoodie, pixel art"
image = pipe(prompt, num_inference_steps=30, generator=torch.manual_seed(0)).images[0]
```
## Fusing adapters into the model
You can use PEFT to easily fuse/unfuse multiple adapters directly into the model weights (both UNet and text encoder) using the [`~diffusers.loaders.LoraLoaderMixin.fuse_lora`] method, which can lead to a speed-up in inference and lower VRAM usage.
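For example, a minimal sketch of fusing and later unfusing the adapters loaded above (assuming `pipe` is the same pipeline with the `"pixel"` and `"toy"` adapters already loaded):
```py
# Assumes `pipe` already has the "pixel" and "toy" adapters loaded as shown earlier.
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.5, 1.0])

# Merge the active adapters directly into the UNet and text encoder weights.
pipe.fuse_lora()

# Inference now runs on the fused weights, without separate LoRA layers.
image = pipe("toy_face of a hacker with a hoodie, pixel art", num_inference_steps=30).images[0]

# Restore the original base weights when the fused adapters are no longer needed.
pipe.unfuse_lora()
```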

View File

@@ -0,0 +1,62 @@
import argparse
import re

import torch
from safetensors.torch import load_file, save_file


def convert_lora(original_state_dict):
    converted_state_dict = {}
    for k, v in original_state_dict.items():
        if "pos_encoder" in k:
            continue

        if "alpha" in k:
            continue
        else:
            diffusers_key = (
                k.replace(".norms.0", ".norm1")
                .replace(".norms.1", ".norm2")
                .replace(".ff_norm", ".norm3")
                .replace(".attention_blocks.0", ".attn1")
                .replace(".attention_blocks.1", ".attn2")
                .replace(".temporal_transformer", "")
                .replace("lora_unet_", "")
            )
            diffusers_key = diffusers_key.replace("to_out_0_", "to_out_")
            diffusers_key = diffusers_key.replace("mid_block_", "mid_block.")
            diffusers_key = diffusers_key.replace("attn1_", "attn1.processor.")
            diffusers_key = diffusers_key.replace("attn2_", "attn2.processor.")
            diffusers_key = diffusers_key.replace(".lora_", "_lora.")
            diffusers_key = re.sub(r'_(\d+)_', r'.\1.', diffusers_key)

            converted_state_dict[diffusers_key] = v

    return converted_state_dict


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    state_dict = load_file(args.ckpt_path)
    if "state_dict" in state_dict.keys():
        state_dict = state_dict["state_dict"]

    converted_state_dict = convert_lora(state_dict)

    # convert to new format
    output_dict = {}
    for module_name, params in converted_state_dict.items():
        if type(params) is not torch.Tensor:
            continue
        output_dict.update({f"unet.{module_name}": params})

    save_file(output_dict, f"{args.output_path}/diffusion_pytorch_model.safetensors")
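Assuming the new script above is saved as, say, `convert_motion_lora.py` (the file name is not shown in this diff), it would be invoked as `python convert_motion_lora.py --ckpt_path <motion_lora>.safetensors --output_path <output_dir>`, writing a `diffusion_pytorch_model.safetensors` file with every key prefixed by `unet.` into the output directory.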

View File

@@ -38,6 +38,9 @@ class FromOriginalVAEMixin:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config_file (`str`, *optional*):
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
@@ -65,6 +68,13 @@ class FromOriginalVAEMixin:
image_size (`int`, *optional*, defaults to 512):
The image size the model was trained on. Use 512 for all Stable Diffusion v1 models and the Stable
Diffusion v2 base model. Use 768 for Stable Diffusion v2.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z
= 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
use_safetensors (`bool`, *optional*, defaults to `None`):
If set to `None`, the safetensors weights are downloaded if they're available **and** if the
safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors
@@ -92,6 +102,7 @@ class FromOriginalVAEMixin:
"""
original_config_file = kwargs.pop("original_config_file", None)
config_file = kwargs.pop("config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -103,6 +114,13 @@ class FromOriginalVAEMixin:
use_safetensors = kwargs.pop("use_safetensors", True)
class_name = cls.__name__
if (config_file is not None) and (original_config_file is not None):
raise ValueError(
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
)
original_config_file = original_config_file or config_file
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
@@ -118,7 +136,10 @@ class FromOriginalVAEMixin:
)
image_size = kwargs.pop("image_size", None)
component = create_diffusers_vae_model_from_ldm(class_name, original_config, checkpoint, image_size=image_size)
scaling_factor = kwargs.pop("scaling_factor", None)
component = create_diffusers_vae_model_from_ldm(
class_name, original_config, checkpoint, image_size=image_size, scaling_factor=scaling_factor
)
vae = component["vae"]
if torch_dtype is not None:
vae = vae.to(torch_dtype)
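A rough sketch of how the new `config_file` and `scaling_factor` arguments might be passed to a VAE loaded through this mixin (the checkpoint link and YAML path below are placeholders, not taken from this diff):
```py
from diffusers import AutoencoderKL

# Hypothetical single-file VAE checkpoint plus its original LDM YAML config.
vae = AutoencoderKL.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/vae.safetensors",
    config_file="v1-inference.yaml",  # mutually exclusive with `original_config_file`
    scaling_factor=0.18215,
)
```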

View File

@@ -38,6 +38,9 @@ class FromOriginalControlNetMixin:
- A link to the `.ckpt` file (for example
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
- A path to a *file* containing all pipeline weights.
config_file (`str`, *optional*):
Filepath to the configuration YAML file associated with the model. If not provided it will default to:
https://raw.githubusercontent.com/lllyasviel/ControlNet/main/models/cldm_v15.yaml
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
dtype is automatically derived from the model's weights.
@@ -89,6 +92,7 @@ class FromOriginalControlNetMixin:
```
"""
original_config_file = kwargs.pop("original_config_file", None)
config_file = kwargs.pop("config_file", None)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
@@ -100,6 +104,12 @@ class FromOriginalControlNetMixin:
use_safetensors = kwargs.pop("use_safetensors", True)
class_name = cls.__name__
if (config_file is not None) and (original_config_file is not None):
raise ValueError(
"You cannot pass both `config_file` and `original_config_file` to `from_single_file`. Please use only one of these arguments."
)
original_config_file = config_file or original_config_file
original_config, checkpoint = fetch_ldm_config_and_checkpoint(
pretrained_model_link_or_path=pretrained_model_link_or_path,
class_name=class_name,
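Likewise, a hedged example of the new `config_file` argument for ControlNet single-file loading (paths are placeholders):
```py
from diffusers import ControlNetModel

# Hypothetical checkpoint link and local cldm YAML; only one of `config_file` and
# `original_config_file` may be passed, per the check added above.
controlnet = ControlNetModel.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/control_sd15_canny.pth",
    config_file="cldm_v15.yaml",
)
```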

View File

@@ -48,6 +48,7 @@ def build_sub_model_components(
load_safety_checker=False,
model_type=None,
image_size=None,
torch_dtype=None,
**kwargs,
):
if component_name in pipeline_components:
@@ -96,7 +97,7 @@ def build_sub_model_components(
from ..pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only, torch_dtype=torch_dtype
)
else:
safety_checker = None

View File

@@ -48,7 +48,6 @@ if is_transformers_available():
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -175,6 +174,7 @@ DIFFUSERS_TO_LDM_MAPPING = {
}
LDM_VAE_KEY = "first_stage_model."
LDM_VAE_DEFAULT_SCALING_FACTOR = 0.18215
LDM_UNET_KEY = "model.diffusion_model."
LDM_CONTROLNET_KEY = "control_model."
LDM_CLIP_PREFIX_TO_REMOVE = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
@@ -518,7 +518,10 @@ def create_vae_diffusers_config(original_config, image_size, scaling_factor=None
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config["model"]["params"]["first_stage_config"]["params"]["ddconfig"]
scaling_factor = scaling_factor or original_config["model"]["params"]["scale_factor"]
if scaling_factor is None and "scale_factor" in original_config["model"]["params"]:
scaling_factor = original_config["model"]["params"]["scale_factor"]
elif scaling_factor is None:
scaling_factor = LDM_VAE_DEFAULT_SCALING_FACTOR
block_out_channels = [vae_params["ch"] * mult for mult in vae_params["ch_mult"]]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
@@ -870,8 +873,17 @@ def create_diffusers_controlnet_model_from_ldm(
controlnet = ControlNetModel(**diffusers_config)
if is_accelerate_available():
for param_name, param in diffusers_format_controlnet_checkpoint.items():
set_module_tensor_to_device(controlnet, param_name, "cpu", value=param)
from ..models.modeling_utils import load_model_dict_into_meta
unexpected_keys = load_model_dict_into_meta(controlnet, diffusers_format_controlnet_checkpoint)
if controlnet._keys_to_ignore_on_load_unexpected is not None:
for pat in controlnet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {controlnet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
controlnet.load_state_dict(diffusers_format_controlnet_checkpoint)
@@ -1034,8 +1046,17 @@ def create_text_encoder_from_ldm_clip_checkpoint(config_name, checkpoint, local_
text_model_dict[diffusers_key] = checkpoint[key]
if is_accelerate_available():
for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
from ..models.modeling_utils import load_model_dict_into_meta
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
text_model_dict.pop("text_model.embeddings.position_ids", None)
@@ -1116,8 +1137,17 @@ def create_text_encoder_from_open_clip_checkpoint(
text_model_dict[diffusers_key] = checkpoint[key]
if is_accelerate_available():
for param_name, param in text_model_dict.items():
set_module_tensor_to_device(text_model, param_name, "cpu", value=param)
from ..models.modeling_utils import load_model_dict_into_meta
unexpected_keys = load_model_dict_into_meta(text_model, text_model_dict)
if text_model._keys_to_ignore_on_load_unexpected is not None:
for pat in text_model._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {text_model.__class__.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
@@ -1164,8 +1194,17 @@ def create_diffusers_unet_model_from_ldm(
unet = UNet2DConditionModel(**unet_config)
if is_accelerate_available():
for param_name, param in diffusers_format_unet_checkpoint.items():
set_module_tensor_to_device(unet, param_name, "cpu", value=param)
from ..models.modeling_utils import load_model_dict_into_meta
unexpected_keys = load_model_dict_into_meta(unet, diffusers_format_unet_checkpoint)
if unet._keys_to_ignore_on_load_unexpected is not None:
for pat in unet._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {unet.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
unet.load_state_dict(diffusers_format_unet_checkpoint)
@@ -1173,7 +1212,7 @@ def create_diffusers_unet_model_from_ldm(
def create_diffusers_vae_model_from_ldm(
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=0.18125
pipeline_class_name, original_config, checkpoint, image_size=None, scaling_factor=None
):
# import here to avoid circular imports
from ..models import AutoencoderKL
@@ -1188,8 +1227,17 @@ def create_diffusers_vae_model_from_ldm(
vae = AutoencoderKL(**vae_config)
if is_accelerate_available():
for param_name, param in diffusers_format_vae_checkpoint.items():
set_module_tensor_to_device(vae, param_name, "cpu", value=param)
from ..models.modeling_utils import load_model_dict_into_meta
unexpected_keys = load_model_dict_into_meta(vae, diffusers_format_vae_checkpoint)
if vae._keys_to_ignore_on_load_unexpected is not None:
for pat in vae._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if len(unexpected_keys) > 0:
logger.warn(
f"Some weights of the model checkpoint were not used when initializing {vae.__name__}: \n {[', '.join(unexpected_keys)]}"
)
else:
vae.load_state_dict(diffusers_format_vae_checkpoint)
@@ -1226,7 +1274,9 @@ def create_text_encoders_and_tokenizers_from_ldm(
try:
config_name = "openai/clip-vit-large-patch14"
text_encoder = create_text_encoder_from_ldm_clip_checkpoint(
config_name, checkpoint, local_files_only=local_files_only
config_name,
checkpoint,
local_files_only=local_files_only,
)
tokenizer = CLIPTokenizer.from_pretrained(config_name, local_files_only=local_files_only)
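The meta-device loading blocks above all follow the same pattern; a condensed, self-contained sketch of the unexpected-key filtering (with stand-in values rather than a real model) could look like:
```py
import re

# Stand-ins: `unexpected_keys` is what `load_model_dict_into_meta` would return,
# and the patterns mimic a model's `_keys_to_ignore_on_load_unexpected` attribute.
unexpected_keys = ["extra_block.weight", "text_model.embeddings.position_ids"]
keys_to_ignore = [r"position_ids"]

for pat in keys_to_ignore:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if len(unexpected_keys) > 0:
    print(f"Some weights of the model checkpoint were not used: {', '.join(unexpected_keys)}")
```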

View File

@@ -151,7 +151,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
sample_max_value: float = 1.0,
algorithm_type: str = "dpmsolver++",
solver_type: str = "midpoint",
lower_order_final: bool = True,
lower_order_final: bool = False,
use_karras_sigmas: Optional[bool] = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
lambda_min_clipped: float = -float("inf"),
@@ -233,7 +233,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
orders = [1, 2, 3] * (steps // 3) + [1, 2]
elif order == 2:
if steps % 2 == 0:
orders = [1, 2] * (steps // 2)
orders = [1, 2] * (steps // 2 - 1) + [1, 1]
else:
orders = [1, 2] * (steps // 2) + [1]
elif order == 1:
@@ -320,7 +320,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if not self.config.lower_order_final and num_inference_steps % self.config.solver_order != 0:
logger.warn(
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=True`."
"Changing scheduler {self.config} to have `lower_order_final` set to True to handle uneven amount of inference steps. Please make sure to always use an even number of `num_inference steps when using `lower_order_final=False`."
)
self.register_to_config(lower_order_final=True)
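A small sketch of what the corrected `solver_order=2` branch now produces when `lower_order_final=True` (the helper name below is ad hoc; it mirrors only the arithmetic in the hunk above, not the scheduler API):
```py
def order_list_order2_lower_final(steps: int) -> list:
    # Even step counts now end with two first-order steps instead of a trailing
    # second-order step; odd step counts are unchanged.
    if steps % 2 == 0:
        return [1, 2] * (steps // 2 - 1) + [1, 1]
    return [1, 2] * (steps // 2) + [1]

print(order_list_order2_lower_final(10))  # [1, 2, 1, 2, 1, 2, 1, 2, 1, 1]
print(order_list_order2_lower_final(9))   # [1, 2, 1, 2, 1, 2, 1, 2, 1]
```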

View File

@@ -37,8 +37,10 @@ from diffusers import (
EulerDiscreteScheduler,
LCMScheduler,
StableDiffusionPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetPipeline,
StableDiffusionXLPipeline,
T2IAdapter,
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
@@ -2175,7 +2177,7 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipeline)
def test_canny_lora(self):
def test_controlnet_canny_lora(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
@@ -2199,6 +2201,34 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
assert np.allclose(original_image, expected_image, atol=1e-04)
release_memory(pipe)
def test_sdxl_t2i_adapter_canny_lora(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
"cpu"
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
adapter=adapter,
torch_dtype=torch.float16,
variant="fp16",
)
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "toy"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
assert images[0].shape == (768, 512, 3)
image_slice = images[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
@nightly
def test_sequential_fuse_unfuse(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)

View File

@@ -35,6 +35,7 @@ from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
enable_full_determinism,
is_flaky,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
@@ -259,6 +260,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
]
assert processors == [True] * len(processors)
@is_flaky
def test_multi(self):
image_encoder = self.get_image_encoder(repo_id="h94/IP-Adapter", subfolder="models/image_encoder")
pipeline = StableDiffusionPipeline.from_pretrained(
@@ -275,7 +277,7 @@ class IPAdapterSDIntegrationTests(IPAdapterNightlyTestsMixin):
inputs["ip_adapter_image"] = [ip_adapter_image, [ip_adapter_image] * 2]
images = pipeline(**inputs).images
image_slice = images[0, :3, :3, -1].flatten()
expected_slice = np.array([0.1704, 0.1296, 0.1272, 0.2212, 0.1514, 0.1479, 0.4172, 0.4263, 0.4360])
expected_slice = np.array([0.5234, 0.5352, 0.5625, 0.5713, 0.5947, 0.6206, 0.5786, 0.6187, 0.6494])
max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice)
assert max_diff < 5e-4

View File

@@ -27,13 +27,7 @@ from diffusers import (
PixArtAlphaPipeline,
Transformer2DModel,
)
from diffusers.utils.testing_utils import (
enable_full_determinism,
numpy_cosine_similarity_distance,
require_torch_gpu,
slow,
torch_device,
)
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
@@ -338,35 +332,37 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
torch.cuda.empty_cache()
def test_pixart_1024(self):
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_1024, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
image = pipe(prompt, generator=generator, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
self.assertLessEqual(max_diff, 1e-4)
expected_slice = np.array([0.1941, 0.2117, 0.2188, 0.1946, 0.218, 0.2124, 0.199, 0.2437, 0.2583])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_512(self):
generator = torch.Generator("cpu").manual_seed(0)
generator = torch.manual_seed(0)
pipe = PixArtAlphaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()
prompt = self.prompt
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
image = pipe(prompt, generator=generator, output_type="np").images
image_slice = image[0, -3:, -3:, -1]
expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
self.assertLessEqual(max_diff, 1e-4)
expected_slice = np.array([0.2637, 0.291, 0.2939, 0.207, 0.2512, 0.2783, 0.2168, 0.2324, 0.2817])
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
def test_pixart_1024_without_resolution_binning(self):
generator = torch.manual_seed(0)
@@ -376,7 +372,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
prompt = self.prompt
height, width = 1024, 768
num_inference_steps = 2
num_inference_steps = 10
image = pipe(
prompt,
@@ -410,7 +406,7 @@ class PixArtAlphaPipelineIntegrationTests(unittest.TestCase):
prompt = self.prompt
height, width = 512, 768
num_inference_steps = 2
num_inference_steps = 10
image = pipe(
prompt,

View File

@@ -672,34 +672,6 @@ class AdapterSDXLPipelineSlowTests(unittest.TestCase):
gc.collect()
torch.cuda.empty_cache()
def test_canny_lora(self):
adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16).to(
"cpu"
)
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
adapter=adapter,
torch_dtype=torch.float16,
variant="fp16",
)
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors")
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
prompt = "toy"
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/t2i_adapter/toy_canny.png"
)
images = pipe(prompt, image=image, generator=generator, output_type="np", num_inference_steps=3).images
assert images[0].shape == (768, 512, 3)
image_slice = images[0, -3:, -3:, -1].flatten()
expected_slice = np.array([0.4284, 0.4337, 0.4319, 0.4255, 0.4329, 0.4280, 0.4338, 0.4420, 0.4226])
assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = (
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0.safetensors"