Compare commits

8 Commits

Author SHA1 Message Date
Dhruv Nair
97ee35f826 update 2026-04-16 11:52:19 +02:00
Dhruv Nair
80ad468e2b update 2026-04-16 09:40:38 +02:00
Dhruv Nair
fe38d77603 update 2026-04-15 11:48:47 +02:00
Alexey Zolotenkov
526498d219 Fix Qwen Image DreamBooth prior-preservation batch ordering (#13441)
Fix Qwen Image DreamBooth prior-preservation batching

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-14 18:00:37 +05:30
HsiaWinter
6a339ce637 fix some dtype issue for gguf / some gpu backends (#13464) 2026-04-13 18:41:01 -10:00
Juan Acevedo
26bb7fa0cb [ptxla] fix pytorch xla inference on TPUs. (#13463)
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
2026-04-14 09:21:26 +05:30
hlky
5063aa5566 FlashPack (#12700)
* FlashPack

* setup

* save_pretrained

* dtype is property

* destination_path

* logging

* pipeline

* ruff

* flashpack_kwargs

* download

* Fix docstring

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* tests

* ignore_cleanup_errors

* -load_flashpack_checkpoint

* Apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-04-13 20:42:17 +05:30
Sayak Paul
62b1071609 [core] fix fa4 integration (#13443)
fix fa4 integration
2026-04-13 11:28:11 +05:30
30 changed files with 1217 additions and 1018 deletions

View File

@@ -1533,9 +1533,9 @@ def main(args):
# from the cat above, but collate_fn also doubles the prompts list. Use half the
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
if prompt_embeds_mask is not None:
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0)
# Convert images to latent space
if args.cache_latents:
model_input = latents_cache[step].sample()
@@ -1602,10 +1602,11 @@ def main(args):
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
# Compute prior loss
prior_loss = torch.mean(
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
target_prior.shape[0], -1
),
1,
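
The core of the fix: with prior preservation, collate_fn concatenates instance and class examples, so embeddings must be grouped per prompt rather than tiled. A minimal standalone sketch of the difference (illustrative only, not the training script itself):

import torch

embeds = torch.tensor([[1.0], [2.0]])  # two prompts: instance, class
tiled = embeds.repeat(2, 1)  # tiles the whole tensor: rows ordered 1, 2, 1, 2
grouped = embeds.repeat_interleave(2, dim=0)  # repeats each row in place: 1, 1, 2, 2

print(tiled.flatten().tolist())    # [1.0, 2.0, 1.0, 2.0]
print(grouped.flatten().tolist())  # [1.0, 1.0, 2.0, 2.0]

repeat_interleave keeps all copies of each prompt together, matching the latent ordering assumed by the torch.chunk split in the loss computation above.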

View File

@@ -146,6 +146,7 @@ _deps = [
"phonemizer",
"opencv-python",
"timm",
"flashpack",
]
# this is a lookup table with items like:
@@ -250,6 +251,7 @@ extras["gguf"] = deps_list("gguf", "accelerate")
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
extras["torchao"] = deps_list("torchao", "accelerate")
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
extras["flashpack"] = deps_list("flashpack")
if os.name == "nt": # windows
extras["flax"] = [] # jax is not supported on windows

View File

@@ -53,4 +53,5 @@ deps = {
"phonemizer": "phonemizer",
"opencv-python": "opencv-python",
"timm": "timm",
"flashpack": "flashpack",
}

View File

@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
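
The fix swaps an availability helper for a version-comparison helper. A standalone sketch of what such a check can look like (a hedged approximation, not the actual diffusers implementation, which routes through its own version utilities):

import operator
from importlib.metadata import PackageNotFoundError, version as installed_version

from packaging import version

_OPS = {"<": operator.lt, "<=": operator.le, "==": operator.eq, ">=": operator.ge, ">": operator.gt}

def is_kernels_version(operation: str, required: str) -> bool:
    # False when `kernels` is not installed; otherwise compare parsed versions.
    try:
        current = installed_version("kernels")
    except PackageNotFoundError:
        return False
    return _OPS[operation](version.parse(current), version.parse(required))

print(is_kernels_version(">=", "0.12.3"))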

View File

@@ -42,6 +42,7 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import (
CONFIG_NAME,
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
HF_ENABLE_PARALLEL_LOADING,
SAFE_WEIGHTS_INDEX_NAME,
@@ -55,6 +56,7 @@ from ..utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_version,
is_flashpack_available,
is_peft_available,
is_torch_version,
logging,
@@ -673,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
variant: str | None = None,
max_shard_size: int | str = "10GB",
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -725,7 +728,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
" the logger on the traceback to understand the reason why the quantized model is not serializable."
)
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
weights_name = WEIGHTS_NAME
if use_flashpack:
weights_name = FLASHPACK_WEIGHTS_NAME
elif safe_serialization:
weights_name = SAFETENSORS_WEIGHTS_NAME
weights_name = _add_variant(weights_name, variant)
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
@@ -752,58 +760,74 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
# Save the model
state_dict = model_to_save.state_dict()
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
if use_flashpack:
if is_flashpack_available():
import flashpack
else:
torch.save(shard, filepath)
logger.error(
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
flashpack.serialization.pack_to_file(
state_dict_or_model=state_dict,
destination_path=os.path.join(save_directory, weights_name),
target_dtype=self.dtype,
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
# Save the model
state_dict_split = split_torch_state_dict_into_shards(
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
)
# Clean the folder from a previous save
if is_main_process:
for filename in os.listdir(save_directory):
if filename in state_dict_split.filename_to_tensors.keys():
continue
full_filename = os.path.join(save_directory, filename)
if not os.path.isfile(full_filename):
continue
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
weights_without_ext = weights_without_ext.replace("{suffix}", "")
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
if (
filename.startswith(weights_without_ext)
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
):
os.remove(full_filename)
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
filepath = os.path.join(save_directory, filename)
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
else:
torch.save(shard, filepath)
if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
logger.info(
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
f"index located at {save_index_file}."
)
else:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
if push_to_hub:
# Create a new empty model card and eventually tag it
@@ -940,6 +964,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
disable_mmap (`bool`, *optional*, defaults to `False`):
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, the model is loaded from `flashpack` weights.
flashpack_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
Kwargs passed to
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
> [!TIP]
> To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log in with `hf auth login`. You can also activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a firewalled environment.
@@ -984,6 +1014,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
disable_mmap = kwargs.pop("disable_mmap", False)
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
if is_parallel_loading_enabled and not low_cpu_mem_usage:
@@ -1212,30 +1244,37 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
subfolder=subfolder or "",
dduf_entries=dduf_entries,
)
elif use_safetensors:
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
else:
if use_flashpack:
weights_name = FLASHPACK_WEIGHTS_NAME
elif use_safetensors:
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
else:
weights_name = None
if weights_name is not None:
try:
resolved_model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
except IOError as e:
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
if not allow_pickle:
raise
logger.warning(
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
)
if resolved_model_file is None and not is_sharded:
resolved_model_file = _get_model_file(
@@ -1275,6 +1314,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
with ContextManagers(init_contexts):
model = cls.from_config(config, **unused_kwargs)
if use_flashpack:
if is_flashpack_available():
import flashpack
else:
logger.error(
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
)
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
if device_map is None:
logger.warning(
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
"the benefit of FlashPack."
)
flashpack_device = torch.device("cpu")
else:
device = device_map[""]
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
raise ValueError(
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
)
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
flashpack.mixin.assign_from_file(
model=model,
path=resolved_model_file[0],
device=flashpack_device,
**flashpack_kwargs,
)
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
if output_loading_info:
logger.warning("`output_loading_info` is not supported with FlashPack.")
return model, {}
return model
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)
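
Taken together, these changes let `save_pretrained` emit a single model.flashpack file and `from_pretrained` map it straight onto a device. A hedged usage sketch (the repo id and paths are placeholders; assumes `flashpack` is installed and a CUDA device is available):

from diffusers import UNet2DModel

model = UNet2DModel.from_pretrained("some/model-repo", subfolder="unet")  # placeholder repo id
model.save_pretrained("./model-flashpack", use_flashpack=True)  # writes model.flashpack

# Reload; per the warning above, pass a concrete device rather than an "auto"-style map.
reloaded = UNet2DModel.from_pretrained(
    "./model-flashpack",
    use_flashpack=True,
    device_map="cuda:0",
)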

View File

@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos, omega)
return out.float()
@@ -400,8 +400,8 @@ class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
]
# AdaLN
sample = self.time_proj(timestep.to(dtype))
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
sample = self.time_proj(timestep)
sample = sample.to(dtype=dtype)
c = self.time_embedding(sample)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
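
Both hunks are dtype fixes: float64 is unavailable on some GPU backends, so the RoPE scale is now computed in float32, and the timestep projection is cast once after `time_proj`. A small self-contained sketch of the float32 frequency computation:

import torch

def rope_frequencies(pos: torch.Tensor, dim: int, theta: float) -> torch.Tensor:
    # float32 keeps this portable to backends without float64 support (e.g., MPS)
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
    omega = 1.0 / (theta ** scale)
    return torch.einsum("...n,d->...nd", pos, omega).float()

pos = torch.arange(8, dtype=torch.float32)
print(rope_frequencies(pos, dim=16, theta=10000.0).shape)  # torch.Size([8, 8])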

View File

@@ -445,10 +445,14 @@ class WanAnimateFaceBlockAttnProcessor:
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
B, T, N, C = encoder_hidden_states.shape
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
value = value.view(B, T, N, attn.heads, -1)
query = attn.norm_q(query)
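
A sketch of the shape dance the comment describes (dimensions are illustrative; the plain nn.Linear stands in for the possibly int8-quantized K projection):

import torch

B, T, N, C = 2, 4, 5, 32
heads = 8
encoder_hidden_states = torch.randn(B, T, N, C)

flat = encoder_hidden_states.flatten(1, 2)  # [B, T*N, C]: 3D, accepted by BnB int8 matmul
k_proj = torch.nn.Linear(C, C)  # stand-in for the quantized K projection
key = k_proj(flat).view(B, T, N, heads, -1)  # restore [B, T, N, H, D_kv]
print(key.shape)  # torch.Size([2, 4, 5, 8, 4])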

View File

@@ -331,7 +331,7 @@ class WanVACETransformer3DModel(
)
if i in self.config.vace_layers:
control_hint, scale = control_hidden_states_list.pop()
hidden_states = hidden_states + control_hint * scale
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
else:
# Prepare VACE hints
control_hidden_states_list = []
@@ -346,7 +346,7 @@ class WanVACETransformer3DModel(
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
if i in self.config.vace_layers:
control_hint, scale = control_hidden_states_list.pop()
hidden_states = hidden_states + control_hint * scale
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
# 6. Output norm, projection & unpatchify
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
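
Under offloading or multi-device maps, the cached VACE hints can live on a different device than the current activation, hence the explicit move before the fused add. A minimal sketch of the pattern:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
hidden_states = torch.randn(1, 16, 8, device=device)
control_hint = torch.randn(1, 16, 8)  # may remain on cpu under offloading
scale = 0.75

hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
print(hidden_states.device)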

View File

@@ -877,10 +877,7 @@ class FluxPipeline(
self.scheduler.config.get("max_shift", 1.15),
)
if XLA_AVAILABLE:
timestep_device = "cpu"
else:
timestep_device = device
timestep_device = device
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,

View File

@@ -28,6 +28,7 @@ from packaging import version
from .. import __version__
from ..utils import (
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
@@ -194,6 +195,7 @@ def filter_model_files(filenames):
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
FLASHPACK_WEIGHTS_NAME,
]
if is_transformers_available():
@@ -413,6 +415,9 @@ def get_class_obj_and_candidates(
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
if class_name.startswith("FlashPack"):
class_name = class_name.removeprefix("FlashPack")
if is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
@@ -760,6 +765,7 @@ def load_sub_model(
provider_options: Any,
disable_mmap: bool,
quantization_config: Any | None = None,
use_flashpack: bool = False,
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""
from ..quantizers import PipelineQuantizationConfig
@@ -838,6 +844,9 @@ def load_sub_model(
loading_kwargs["variant"] = model_variants.pop(name, None)
loading_kwargs["use_safetensors"] = use_safetensors
if is_diffusers_model:
loading_kwargs["use_flashpack"] = use_flashpack
if from_flax:
loading_kwargs["from_flax"] = True
@@ -887,7 +896,7 @@ def load_sub_model(
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
# remove hooks
remove_hook_from_module(loaded_sub_model, recurse=True)
needs_offloading_to_cpu = device_map[""] == "cpu"
@@ -1093,6 +1102,7 @@ def _get_ignore_patterns(
allow_pickle: bool,
use_onnx: bool,
is_onnx: bool,
use_flashpack: bool,
variant: str | None = None,
) -> list[str]:
if (
@@ -1118,6 +1128,9 @@ def _get_ignore_patterns(
if not use_onnx:
ignore_patterns += ["*.onnx", "*.pb"]
elif use_flashpack:
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
else:
ignore_patterns = ["*.safetensors", "*.msgpack"]

View File

@@ -244,6 +244,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
variant: str | None = None,
max_shard_size: int | str | None = None,
push_to_hub: bool = False,
use_flashpack: bool = False,
**kwargs,
):
"""
@@ -341,6 +342,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
save_kwargs = {}
@@ -351,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
if save_method_accept_max_shard_size and max_shard_size is not None:
# max_shard_size is expected to not be None in ModelMixin
save_kwargs["max_shard_size"] = max_shard_size
if save_method_accept_flashpack:
save_kwargs["use_flashpack"] = use_flashpack
if save_method_accept_peft_format:
# Set save_peft_format=False for transformers>=5.0.0 compatibility
# In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix
@@ -781,6 +785,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
use_onnx = kwargs.pop("use_onnx", None)
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
quantization_config = kwargs.pop("quantization_config", None)
use_flashpack = kwargs.pop("use_flashpack", False)
disable_mmap = kwargs.pop("disable_mmap", False)
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
@@ -1071,6 +1076,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
provider_options=provider_options,
disable_mmap=disable_mmap,
quantization_config=quantization_config,
use_flashpack=use_flashpack,
)
logger.info(
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
@@ -1576,6 +1582,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
option should only be set to `True` for repositories you trust and in which you have read the code, as
it will execute code present on the Hub on your local machine.
use_flashpack (`bool`, *optional*, defaults to `False`):
If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, FlashPack
weights will never be downloaded.
Returns:
`os.PathLike`:
@@ -1600,6 +1609,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
trust_remote_code = kwargs.pop("trust_remote_code", False)
dduf_file: dict[str, DDUFEntry] | None = kwargs.pop("dduf_file", None)
use_flashpack = kwargs.pop("use_flashpack", False)
if dduf_file:
if custom_pipeline:
@@ -1719,6 +1729,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
allow_pickle,
use_onnx,
pipeline_class._is_onnx,
use_flashpack,
variant,
)
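
At the pipeline level the flag is simply forwarded: it is passed into each diffusers sub-model's from_pretrained/save_pretrained and steers which weight files are downloaded or ignored. A hedged end-to-end sketch (placeholder repo id; assumes the repo ships FlashPack weights):

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some/pipeline-repo", use_flashpack=True)  # placeholder id
pipe.save_pretrained("./pipeline-flashpack", use_flashpack=True)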

View File

@@ -24,6 +24,8 @@ from .constants import (
DEPRECATED_REVISION_ARGS,
DIFFUSERS_DYNAMIC_MODULE_NAME,
DIFFUSERS_LOAD_ID_FIELDS,
FLASHPACK_FILE_EXTENSION,
FLASHPACK_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
GGUF_FILE_EXTENSION,
HF_ENABLE_PARALLEL_LOADING,
@@ -76,6 +78,7 @@ from .import_utils import (
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_flashpack_available,
is_flax_available,
is_ftfy_available,
is_gguf_available,

View File

@@ -34,6 +34,8 @@ ONNX_WEIGHTS_NAME = "model.onnx"
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
SAFETENSORS_FILE_EXTENSION = "safetensors"
FLASHPACK_WEIGHTS_NAME = "model.flashpack"
FLASHPACK_FILE_EXTENSION = "flashpack"
GGUF_FILE_EXTENSION = "gguf"
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")

View File

@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
_av_available, _av_version = _is_package_available("av")
@@ -361,6 +362,10 @@ def is_gguf_available():
return _gguf_available
def is_flashpack_available():
return _flashpack_available
def is_torchao_available():
return _torchao_available

View File

@@ -465,8 +465,7 @@ class UNetTesterMixin:
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
@@ -481,9 +480,9 @@ class UNetTesterMixin:
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
class ModelTesterMixin:

View File

@@ -205,6 +205,11 @@ class BaseModelTesterConfig:
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
return {}
@property
def torch_dtype(self) -> torch.dtype:
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
return torch.float32
@property
def output_shape(self) -> Optional[tuple]:
"""Expected output shape for output validation tests."""
@@ -287,9 +292,8 @@ class ModelTesterMixin:
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
)
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -309,9 +313,8 @@ class ModelTesterMixin:
new_model.to(torch_device)
inputs_dict = self.get_dummy_inputs()
image = model(**inputs_dict, return_dict=False)[0]
new_image = new_model(**inputs_dict, return_dict=False)[0]
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
@@ -339,9 +342,8 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
first = model(**inputs_dict, return_dict=False)[0]
second = model(**inputs_dict, return_dict=False)[0]
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
first_flat = first.flatten()
second_flat = second.flatten()
@@ -398,9 +400,8 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
outputs_dict = model(**self.get_dummy_inputs())
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
@@ -527,10 +528,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
@@ -569,10 +568,8 @@ class ModelTesterMixin:
new_model = new_model.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
new_output = new_model(**inputs_dict, return_dict=False)[0]
inputs_dict_new = self.get_dummy_inputs()
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
assert_tensors_close(
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
@@ -622,10 +619,8 @@ class ModelTesterMixin:
model_parallel = model_parallel.to(torch_device)
torch.manual_seed(0)
# Re-create inputs only if they contain a generator (which needs to be reset)
if "generator" in inputs_dict:
inputs_dict = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
inputs_dict_parallel = self.get_dummy_inputs()
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
assert_tensors_close(
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"

View File

@@ -92,6 +92,9 @@ class TorchCompileTesterMixin:
model.eval()
model.compile_repeated_blocks(fullgraph=True)
if self.model_class.__name__ == "UNet2DConditionModel":
recompile_limit = 2
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(recompile_limit=recompile_limit),

View File

@@ -359,15 +359,7 @@ class QuantizationTesterMixin:
if isinstance(module, torch.nn.Linear):
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
# Get model dtype from first parameter
model_dtype = next(model.parameters()).dtype
inputs = self.get_dummy_inputs()
# Cast inputs to model dtype
inputs = {
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs.items()
}
output = model(**inputs, return_dict=False)[0]
assert output is not None, "Model output is None after dequantization"
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
@@ -575,33 +567,28 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
@torch.no_grad()
def test_bnb_keep_modules_in_fp32(self):
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
if not fp32_modules:
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
self.model_class._keep_in_fp32_modules = ["proj_out"]
model = self._create_quantized_model(config_kwargs)
model.to(torch_device)
try:
model = self._create_quantized_model(config_kwargs)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
assert module.weight.dtype == torch.float32, (
f"Module {name} should be FP32 but is {module.weight.dtype}"
)
else:
assert module.weight.dtype == torch.uint8, (
f"Module {name} should be uint8 but is {module.weight.dtype}"
)
inputs = self.get_dummy_inputs()
_ = model(**inputs)
finally:
if original_fp32_modules is not None:
self.model_class._keep_in_fp32_modules = original_fp32_modules
inputs = self.get_dummy_inputs()
_ = model(**inputs)
def test_bnb_modules_to_not_convert(self):
"""Test that modules_to_not_convert parameter works correctly."""

View File

@@ -159,21 +159,36 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
return {
"hidden_states": randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
(batch_size, height * width, num_latent_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
(batch_size, sequence_length, embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"pooled_projections": randn_tensor(
(batch_size, embedding_dim), generator=self.generator, device=torch_device
(batch_size, embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"img_ids": randn_tensor(
(height * width, num_image_channels), generator=self.generator, device=torch_device
(height * width, num_image_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"txt_ids": randn_tensor(
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
(sequence_length, num_image_channels),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size),
}
@@ -320,6 +335,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux Transformer."""
@property
def torch_dtype(self):
return torch.float16
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
"""Quanto quantization tests for Flux Transformer."""

View File

@@ -91,11 +91,13 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}

View File

@@ -113,27 +113,32 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"encoder_hidden_states_image": randn_tensor(
(batch_size, clip_seq_len, clip_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"pose_hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"face_pixel_values": randn_tensor(
(batch_size, 3, inference_segment_length, face_height, face_width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
}

View File

@@ -96,16 +96,19 @@ class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
(batch_size, num_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"control_hidden_states": randn_tensor(
(batch_size, vace_in_channels, num_frames, height, width),
generator=self.generator,
device=torch_device,
dtype=self.torch_dtype,
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
@@ -24,39 +26,64 @@ from ...testing_utils import (
slow,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
_LAYERWISE_CASTING_XFAIL_REASON = (
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
)
class UNet1DTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (standard variant)."""
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet1DModel
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (14, 16)
return (4, 14, 16)
@property
def main_input_name(self):
return "sample"
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
def get_init_dict(self):
return {
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def test_determinism(self):
super().test_determinism()
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
super().test_output()
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (8, 8, 16, 16),
"in_channels": 14,
"out_channels": 14,
@@ -70,40 +97,18 @@ class UNet1DTesterConfig(BaseModelTesterConfig):
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
"act_fn": "swish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DHubLoading(UNet1DTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.get_dummy_inputs())
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@@ -126,7 +131,12 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
# fmt: off
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
# fmt: on
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@slow
def test_unet_1d_maestro(self):
@@ -147,29 +157,98 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
assert (output_sum - 224.0896).abs() < 0.5
assert (output_max - 0.0607).abs() < 4e-4
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
# =============================================================================
# UNet1D RL (Value Function) Model Tests
# =============================================================================
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass
class UNet1DRLTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet1DModel testing (RL value function variant)."""
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet1DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet1DModel
def dummy_input(self):
batch_size = 4
num_features = 14
seq_len = 16
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
time_step = torch.tensor([10] * batch_size).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 14, 16)
@property
def output_shape(self):
return (1,)
return (4, 14, 1)
@property
def main_input_name(self):
return "sample"
def test_determinism(self):
super().test_determinism()
def get_init_dict(self):
return {
def test_outputs_equivalence(self):
super().test_outputs_equivalence()
def test_from_save_pretrained(self):
super().test_from_save_pretrained()
def test_from_save_pretrained_variant(self):
super().test_from_save_pretrained_variant()
def test_model_from_pretrained(self):
super().test_model_from_pretrained()
def test_output(self):
# UNetRL is a value function with a different output shape
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
@unittest.skip("Test not supported.")
def test_ema_training(self):
pass
@unittest.skip("Test not supported.")
def test_training(self):
pass
@unittest.skip("Test not supported.")
def test_layerwise_casting_training(self):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 14,
"out_channels": 14,
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
@@ -185,54 +264,18 @@ class UNet1DRLTesterConfig(BaseModelTesterConfig):
"time_embedding_type": "positional",
"act_fn": "mish",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_features = 14
seq_len = 16
return {
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
"timestep": torch.tensor([10] * batch_size).to(torch_device),
}
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Not implemented yet for this UNet")
def test_forward_with_norm_groups(self):
pass
@torch.no_grad()
def test_output(self):
# UNetRL is a value-function with different output shape (batch, 1)
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
output = model(**inputs_dict, return_dict=False)[0]
assert output is not None
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
assert output.shape == expected_shape, "Input and output shapes do not match"
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
def test_from_pretrained_hub(self):
value_function, vf_loading_info = UNet1DModel.from_pretrained(
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
)
assert value_function is not None
assert len(vf_loading_info["missing_keys"]) == 0
self.assertIsNotNone(value_function)
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
value_function.to(torch_device)
image = value_function(**self.get_dummy_inputs())
image = value_function(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@@ -256,4 +299,31 @@ class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
# fmt: off
expected_output_slice = torch.tensor([165.25] * seq_len)
# fmt: on
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# Not implemented yet for this UNet
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_inference(self):
pass
@pytest.mark.xfail(
reason=(
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
"2. Unskip this test."
),
)
def test_layerwise_casting_memory(self):
pass

View File

@@ -15,11 +15,12 @@
import gc
import math
import unittest
import pytest
import torch
from diffusers import UNet2DModel
from diffusers.utils import logging
from ...testing_utils import (
backend_empty_cache,
@@ -30,40 +31,39 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
# =============================================================================
# Standard UNet2D Model Tests
# =============================================================================
class UNet2DTesterConfig(BaseModelTesterConfig):
"""Base configuration for standard UNet2DModel testing."""
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 2,
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
@@ -74,22 +74,11 @@ class UNet2DTesterConfig(BaseModelTesterConfig):
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_mid_block_attn_groups(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["add_attention"] = True
init_dict["attn_norm_num_groups"] = 4
@@ -98,11 +87,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
model.to(torch_device)
model.eval()
assert model.mid_block.attentions[0].group_norm is not None, (
"Mid block Attention group norm should exist but does not."
self.assertIsNotNone(
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
)
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
"Mid block Attention group norm does not have the expected number of groups."
self.assertEqual(
model.mid_block.attentions[0].group_norm.num_groups,
init_dict["attn_norm_num_groups"],
"Mid block Attention group norm does not have the expected number of groups.",
)
with torch.no_grad():
@@ -111,15 +102,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_mid_block_none(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
mid_none_init_dict = self.get_init_dict()
mid_none_inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
mid_none_init_dict["mid_block_type"] = None
model = self.model_class(**init_dict)
@@ -130,7 +119,7 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
mid_none_model.to(torch_device)
mid_none_model.eval()
assert mid_none_model.mid_block is None, "Mid block should not exist."
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
with torch.no_grad():
output = model(**inputs_dict)
@@ -144,10 +133,8 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
if isinstance(mid_none_output, dict):
mid_none_output = mid_none_output.to_tuple()[0]
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"AttnUpBlock2D",
@@ -156,32 +143,41 @@ class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
"UpBlock2D",
"DownBlock2D",
}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
attention_head_dim = 8
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
# =============================================================================
# UNet2D LDM Model Tests
# =============================================================================
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel LDM variant testing."""
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"in_channels": 4,
"out_channels": 4,
@@ -191,34 +187,17 @@ class UNet2DLDMTesterConfig(BaseModelTesterConfig):
"down_block_types": ("DownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "UpBlock2D"),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
}
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.get_dummy_inputs()).sample
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@@ -226,7 +205,7 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
def test_from_pretrained_accelerate(self):
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
model.to(torch_device)
image = model(**self.get_dummy_inputs()).sample
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@@ -286,31 +265,44 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
attention_head_dim = 32
block_out_channels = (32, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
# =============================================================================
# NCSN++ Model Tests
# =============================================================================
class NCSNppTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DModel NCSN++ variant testing."""
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DModel
main_input_name = "sample"
@property
def model_class(self):
return UNet2DModel
def dummy_input(self, sizes=(32, 32)):
batch_size = 4
num_channels = 3
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
return {"sample": noise, "timestep": time_step}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": [32, 64, 64, 64],
"in_channels": 3,
"layers_per_block": 1,
@@ -332,71 +324,17 @@ class NCSNppTesterConfig(BaseModelTesterConfig):
"SkipUpBlock2D",
],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
}
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_keep_in_fp32_modules(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_from_save_pretrained_dtype_inference(self):
pass
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass
@pytest.mark.skip(
"To make layerwise casting work with this model, we will have to update the implementation. "
"Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_training(self):
pass
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestNCSNppHubLoading(NCSNppTesterConfig):
@slow
def test_from_pretrained_hub(self):
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
inputs = self.get_dummy_inputs()
inputs = self.dummy_input
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
inputs["sample"] = noise
image = model(**inputs)
@@ -423,7 +361,7 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
def test_output_pretrained_ve_large(self):
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
@@ -444,4 +382,35 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
# fmt: on
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# not required for this model
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"UNetMidBlock2D",
}
block_out_channels = (32, 64, 64, 64)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, block_out_channels=block_out_channels
)
def test_effective_gradient_checkpointing(self):
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_inference(self):
pass
@unittest.skip(
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
)
def test_layerwise_casting_memory(self):
pass

View File

@@ -20,7 +20,6 @@ import tempfile
import unittest
from collections import OrderedDict
import pytest
import torch
from huggingface_hub import snapshot_download
from parameterized import parameterized
@@ -53,24 +52,17 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
IPAdapterTesterMixin,
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
UNetTesterMixin,
)
if is_peft_available():
from peft import LoraConfig
from ..testing_utils.lora import check_if_lora_correctly_set
from peft.tuners.tuners_utils import BaseTunerLayer
logger = logging.get_logger(__name__)
@@ -90,6 +82,16 @@ def get_unet_lora_config():
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
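# Note: this helper returns True as soon as any BaseTunerLayer is found, so it
# only confirms that at least one LoRA layer was injected, not that every
# targeted module was wrapped.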
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -352,28 +354,34 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
return custom_diffusion_attn_procs
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet2DConditionModel testing."""
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
# We override the items here because the unet under consideration is small.
model_split_percents = [0.5, 0.34, 0.4]
@property
def model_class(self):
return UNet2DConditionModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def output_shape(self) -> tuple[int, int, int]:
def input_shape(self):
return (4, 16, 16)
@property
def model_split_percents(self) -> list[float]:
return [0.5, 0.34, 0.4]
def output_shape(self):
return (4, 16, 16)
@property
def main_input_name(self) -> str:
return "sample"
def get_init_dict(self) -> dict:
"""Return UNet2D model initialization arguments."""
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
@@ -385,24 +393,26 @@ class UNet2DConditionTesterConfig(BaseModelTesterConfig):
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
"""Return dummy inputs for UNet2D model."""
batch_size = 4
num_channels = 4
sizes = (16, 16)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
def test_model_with_attention_head_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -417,13 +427,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_use_linear_projection(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["use_linear_projection"] = True
@@ -437,13 +446,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (8, 8)
@@ -457,13 +465,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_simple_projection(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -482,13 +489,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_class_embeddings_concat(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
batch_size, _, _, sample_size = inputs_dict["sample"].shape
@@ -508,287 +514,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
def test_pickle(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample = model(**inputs_dict).sample
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
assert output.shape == expected_shape, "Input and output shapes do not match"
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
"""Hub checkpoint loading tests for UNet2DConditionModel."""
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
inputs_dict = self.get_dummy_inputs()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
inputs_dict = self.get_dummy_inputs()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for UNet2DConditionModel."""
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
"""Test that deprecated load_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
# important to still check the rest of the assertions.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
"""Test that deprecated save_attn_procs method raises FutureWarning."""
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
model.save_attn_procs(os.path.join(tmpdirname))
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for UNet2DConditionModel."""
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
"""Training tests for UNet2DConditionModel."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
"""Attention processor tests for UNet2DConditionModel."""
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -813,7 +544,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
assert output is not None
def test_model_sliceable_head_dim(self):
init_dict = self.get_init_dict()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -831,6 +562,21 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
for module in model.children():
check_sliceable_dim_attr(module)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CrossAttnUpBlock2D",
"CrossAttnDownBlock2D",
"UNetMidBlock2DCrossAttn",
"UpBlock2D",
"Transformer2DModel",
"DownBlock2D",
}
attention_head_dim = (8, 16)
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
@@ -872,8 +618,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -900,8 +645,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
]
)
def test_model_xattn_mask(self, mask_dtype):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
model.to(torch_device)
@@ -931,13 +675,39 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
)
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
# maybe it's fine that this only works for the unclip use-case.
@mark.skip(
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
)
def test_model_xattn_padding(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
"""Custom Diffusion processor tests for UNet2DConditionModel."""
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
model.to(torch_device)
model.eval()
cond = inputs_dict["encoder_hidden_states"]
with torch.no_grad():
full_cond_out = model(**inputs_dict).sample
assert full_cond_out is not None
batch, tokens, _ = cond.shape
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
assert trunc_mask_out.allclose(keeplast_out), (
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
)
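# A minimal sketch (hypothetical helpers, not diffusers API) of the two padding
# strategies contrasted in the skip reason above: the current code appends
# `target_length` keep-tokens (what unclip needs), while stable-diffusion-style
# cross-attention would need to append only the remaining length.
import torch.nn.functional as F

def _pad_like_unclip(mask, target_length):
    # current behavior: append target_length extra keep-tokens
    return F.pad(mask, (0, target_length), value=True)

def _pad_by_remaining(mask, target_length):
    # what stable-diffusion cross-attn needs: pad up to target_length exactly
    return F.pad(mask, (0, target_length - mask.shape[-1]), value=True)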
def test_custom_diffusion_processors(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -963,8 +733,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
assert (sample1 - sample2).abs().max() < 3e-3
def test_custom_diffusion_save_load(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -984,7 +754,7 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -1003,8 +773,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_custom_diffusion_xformers_on_off(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1028,28 +798,41 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
def test_pickle(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
"""IP Adapter tests for UNet2DConditionModel."""
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@property
def ip_adapter_processor_cls(self):
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
model = self.model_class(**init_dict)
model.to(torch_device)
def create_ip_adapter_state_dict(self, model):
return create_ip_adapter_state_dict(model)
with torch.no_grad():
sample = model(**inputs_dict).sample
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
# for ip-adapter, image_embeds has shape [batch_size, num_image, embed_dim]
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
return inputs_dict
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
def test_asymmetrical_unet(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
# Add asymmetry to configs
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
output = model(**inputs_dict).sample
expected_shape = inputs_dict["sample"].shape
# Check if input and output shapes are the same
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_ip_adapter(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1122,8 +905,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
def test_ip_adapter_plus(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = (8, 16)
@@ -1195,16 +977,185 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
"""Torch compile tests for UNet2DConditionModel."""
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
def test_torch_compile_repeated_blocks(self):
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
loaded_model = loaded_model.to(torch_device)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
@parameterized.expand(
[
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
]
)
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_torch_accelerator
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
loaded_model = self.model_class.from_pretrained(
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
)
new_output = loaded_model(**inputs_dict)
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)
@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
with self.assertWarns(FutureWarning) as warning:
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
warning_message = str(warning.warnings[0].message)
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
# important to still check the rest of the assertions.
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
"LoRA injected UNet should produce different results."
)
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
"Loading from a saved checkpoint should produce identical results."
)
@require_peft_backend
def test_save_attn_procs_raise_warning(self):
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertWarns(FutureWarning) as warning:
model.save_attn_procs(tmpdirname)
warning_message = str(warning.warnings[0].message)
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for UNet2DConditionModel."""
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
def prepare_init_args_and_inputs_for_common(self):
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
@slow

View File

@@ -18,44 +18,47 @@ import unittest
import numpy as np
import torch
from diffusers import UNet3DConditionModel
from diffusers.models import ModelMixin, UNet3DConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
)
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
enable_full_determinism()
logger = logging.get_logger(__name__)
@skip_mps
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNet3DConditionModel testing."""
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"
@property
def model_class(self):
return UNet3DConditionModel
def dummy_input(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 4, 16, 16)
@property
def output_shape(self):
return (4, 4, 16, 16)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (4, 8),
"norm_num_groups": 4,
"down_block_types": (
@@ -70,25 +73,27 @@ class UNet3DConditionTesterConfig(BaseModelTesterConfig):
"layers_per_block": 1,
"sample_size": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
num_frames = 4
sizes = (16, 16)
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
return {
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
}
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
# Overriding because `norm_num_groups` needs to be different for this model.
def test_forward_with_norm_groups(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
@@ -102,74 +107,39 @@ class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTes
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
# Overriding since the UNet3D outputs a different structure.
@torch.no_grad()
def test_determinism(self):
model = self.model_class(**self.get_init_dict())
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
inputs_dict = self.get_dummy_inputs()
with torch.no_grad():
# Warmup pass when using mps (see #372)
if torch_device == "mps" and isinstance(model, ModelMixin):
model(**self.dummy_input)
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
first = model(**inputs_dict)
if isinstance(first, dict):
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
out_1 = out_1[~np.isnan(out_1)]
out_2 = out_2[~np.isnan(out_2)]
max_diff = np.amax(np.abs(out_1 - out_2))
assert max_diff <= 1e-5
def test_feed_forward_chunking(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
assert output.shape == output_2.shape, "Shape doesn't match"
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
self.assertLessEqual(max_diff, 1e-5)
def test_model_attention_slicing(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (16, 32)
init_dict["attention_head_dim"] = 8
@@ -192,3 +162,22 @@ class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterM
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_feed_forward_chunking(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["block_out_channels"] = (32, 64)
init_dict["norm_num_groups"] = 32
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)[0]
model.enable_forward_chunking()
with torch.no_grad():
output_2 = model(**inputs_dict)[0]
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
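# For reference, forward chunking can also be tuned explicitly; a hedged sketch
# of the knob exercised above (the chunk_size/dim values are illustrative):
#
#     model.enable_forward_chunking(chunk_size=1, dim=0)
#     with torch.no_grad():
#         output_chunked = model(**inputs_dict)[0]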

View File

@@ -13,42 +13,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import pytest
import torch
from torch import nn
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
from diffusers.utils import logging
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetControlNetXSModel testing."""
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetControlNetXSModel
main_input_name = "sample"
@property
def model_class(self):
return UNetControlNetXSModel
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
conditioning_scale = 1
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"controlnet_cond": controlnet_cond,
"conditioning_scale": conditioning_scale,
}
@property
def input_shape(self):
return (4, 16, 16)
@property
def output_shape(self):
return (4, 16, 16)
@property
def main_input_name(self):
return "sample"
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 16,
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
@@ -63,23 +80,11 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
"ctrl_max_norm_num_groups": 2,
"ctrl_conditioning_embedding_out_channels": (2, 2),
}
def get_dummy_inputs(self):
batch_size = 4
num_channels = 4
sizes = (16, 16)
conditioning_image_size = (3, 32, 32)
return {
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
"timestep": torch.tensor([10]).to(torch_device),
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
"conditioning_scale": 1,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_unet(self):
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
return UNet2DConditionModel(
block_out_channels=(4, 8),
layers_per_block=2,
@@ -94,16 +99,10 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
)
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
"""Build the ControlNetXS-Adapter from a UNet."""
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
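# Sketch of how the two helpers above are combined by the tests below
# (see test_from_unet); `from_unet` is the classmethod under test:
#
#     unet = self.get_dummy_unet()
#     controlnet = self.get_dummy_controlnet_from_unet(unet)
#     model = UNetControlNetXSModel.from_unet(unet, controlnet)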
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
pass
def test_from_unet(self):
unet = self.get_dummy_unet()
controlnet = self.get_dummy_controlnet_from_unet(unet)
@@ -116,7 +115,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
# # check unet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
"time_embedding",
"conv_in",
@@ -153,7 +152,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
# # check controlnet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = {
"controlnet_cond_embedding": "controlnet_cond_embedding",
"conv_in": "ctrl_conv_in",
@@ -194,12 +193,12 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
for p in module.parameters():
assert p.requires_grad
init_dict = self.get_init_dict()
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = UNetControlNetXSModel(**init_dict)
model.freeze_unet_params()
# # check unet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_unet = [
model.base_time_embedding,
model.base_conv_in,
@@ -237,7 +236,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert_frozen(u.upsamplers)
# # check controlnet
# everything except down,mid,up blocks
# everything except down,mid,up blocks
modules_from_controlnet = [
model.controlnet_cond_embedding,
model.ctrl_conv_in,
@@ -268,6 +267,16 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
for u in model.up_blocks:
assert_unfrozen(u.ctrl_to_base)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@is_flaky
def test_forward_no_control(self):
unet = self.get_dummy_unet()
@@ -278,7 +287,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
unet = unet.to(torch_device)
model = model.to(torch_device)
input_ = self.get_dummy_inputs()
input_ = self.dummy_input
control_specific_input = ["controlnet_cond", "conditioning_scale"]
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
@@ -303,7 +312,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
model = model.to(torch_device)
model_mix_time = model_mix_time.to(torch_device)
input_ = self.get_dummy_inputs()
input_ = self.dummy_input
with torch.no_grad():
output = model(**input_).sample
@@ -311,14 +320,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
assert output.shape == output_mix_time.shape
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"Transformer2DModel",
"UNetMidBlock2DCrossAttn",
"ControlNetXSCrossAttnDownBlock2D",
"ControlNetXSCrossAttnMidBlock2D",
"ControlNetXSCrossAttnUpBlock2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test not supported.")
def test_forward_with_norm_groups(self):
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
pass

View File

@@ -16,10 +16,10 @@
import copy
import unittest
import pytest
import torch
from diffusers import UNetSpatioTemporalConditionModel
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
@@ -28,34 +28,45 @@ from ...testing_utils import (
skip_mps,
torch_device,
)
from ..test_modeling_common import UNetTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
logger = logging.get_logger(__name__)
enable_full_determinism()
@skip_mps
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
model_class = UNetSpatioTemporalConditionModel
main_input_name = "sample"
@property
def model_class(self):
return UNetSpatioTemporalConditionModel
def dummy_input(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
@property
def input_shape(self):
return (2, 2, 4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
@property
def main_input_name(self):
return "sample"
@property
def fps(self):
return 6
@@ -72,8 +83,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
def addition_time_embed_dim(self):
return 32
def get_init_dict(self):
return {
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": (
"CrossAttnDownBlockSpatioTemporal",
@@ -92,23 +103,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
"addition_time_embed_dim": self.addition_time_embed_dim,
}
def get_dummy_inputs(self):
batch_size = 2
num_frames = 2
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
return {
"sample": noise,
"timestep": time_step,
"encoder_hidden_states": encoder_hidden_states,
"added_time_ids": self._get_add_time_ids(),
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def _get_add_time_ids(self, do_classifier_free_guidance=True):
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
@@ -128,15 +124,43 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
return add_time_ids
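# Note: the init dict above sets projection_class_embeddings_input_dim to
# addition_time_embed_dim * 3 because _get_add_time_ids packs exactly three
# conditioning scalars per sample: fps, motion_bucket_id and noise_aug_strength.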
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
@pytest.mark.skip("Number of Norm Groups is not configurable")
@unittest.skip("Number of Norm Groups is not configurable")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("Deprecated functionality")
def test_model_attention_slicing(self):
pass
@unittest.skip("Not supported")
def test_model_with_use_linear_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_simple_projection(self):
pass
@unittest.skip("Not supported")
def test_model_with_class_embeddings_concat(self):
pass
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
def test_model_with_num_attention_heads_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
model = self.model_class(**init_dict)
@@ -149,13 +173,12 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_cross_attention_dim_tuple(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["cross_attention_dim"] = (32, 32)
@@ -169,13 +192,27 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
if isinstance(output, dict):
output = output.sample
assert output is not None
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
num_attention_heads = (8, 16)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, num_attention_heads=num_attention_heads
)
def test_pickle(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["num_attention_heads"] = (8, 16)
@@ -188,33 +225,3 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
sample_copy = copy.copy(sample)
assert (sample - sample_copy).abs().max() < 1e-4
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict = self.get_init_dict()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
== "XFormersAttnProcessor"
), "xformers is not enabled"
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"TransformerSpatioTemporalModel",
"CrossAttnDownBlockSpatioTemporal",
"DownBlockSpatioTemporal",
"UpBlockSpatioTemporal",
"CrossAttnUpBlockSpatioTemporal",
"UNetMidBlockSpatioTemporal",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
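Both gradient-checkpointing tests delegate to the shared mixin hook. Roughly, the check amounts to the following; enable_gradient_checkpointing is the public ModelMixin API, but the set comparison is an illustrative sketch of the mixin's behavior, not its exact code:

model = self.model_class(**init_dict)
model.enable_gradient_checkpointing()
checkpointed = {
    module.__class__.__name__
    for module in model.modules()
    if getattr(module, "gradient_checkpointing", False)
}
# Every spatio-temporal block class named in expected_set should now checkpoint.
assert checkpointed == expected_set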

View File

@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import tempfile
import unittest
from diffusers import AutoPipelineForText2Image
from diffusers.models.auto_model import AutoModel
from ..testing_utils import is_torch_available, require_flashpack, require_torch_gpu
if is_torch_available():
import torch
class FlashPackTests(unittest.TestCase):
model_id: str = "hf-internal-testing/tiny-flux-pipe"
@require_flashpack
def test_save_load_model(self):
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
with tempfile.TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir, use_flashpack=True)
self.assertTrue((pathlib.Path(temp_dir) / "model.flashpack").exists())
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True)
@require_flashpack
def test_save_load_pipeline(self):
pipeline = AutoPipelineForText2Image.from_pretrained(self.model_id)
with tempfile.TemporaryDirectory() as temp_dir:
pipeline.save_pretrained(temp_dir, use_flashpack=True)
self.assertTrue((pathlib.Path(temp_dir) / "transformer" / "model.flashpack").exists())
self.assertTrue((pathlib.Path(temp_dir) / "vae" / "model.flashpack").exists())
pipeline = AutoPipelineForText2Image.from_pretrained(temp_dir, use_flashpack=True)
@require_torch_gpu
@require_flashpack
def test_load_model_device_str(self):
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
with tempfile.TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir, use_flashpack=True)
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "cuda"})
self.assertTrue(model.device.type == "cuda")
@require_torch_gpu
@require_flashpack
def test_load_model_device(self):
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
with tempfile.TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir, use_flashpack=True)
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": torch.device("cuda")})
self.assertTrue(model.device.type == "cuda")
@require_flashpack
def test_load_model_device_auto(self):
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
with tempfile.TemporaryDirectory() as temp_dir:
model.save_pretrained(temp_dir, use_flashpack=True)
with self.assertRaises(ValueError):
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "auto"})
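Read together, these tests pin down the FlashPack surface: save_pretrained(..., use_flashpack=True) writes a model.flashpack file (one per component for pipelines), from_pretrained(..., use_flashpack=True) loads it back, explicit device_map entries such as {"": "cuda"} are honored, and "auto" raises ValueError. A condensed sketch of the same round trip, using only the calls exercised above:

import pathlib
import tempfile

from diffusers.models.auto_model import AutoModel

model = AutoModel.from_pretrained("hf-internal-testing/tiny-flux-pipe", subfolder="transformer")
with tempfile.TemporaryDirectory() as temp_dir:
    model.save_pretrained(temp_dir, use_flashpack=True)
    assert (pathlib.Path(temp_dir) / "model.flashpack").exists()
    # On GPU, device_map={"": "cuda"} also works; device_map={"": "auto"} raises.
    model = AutoModel.from_pretrained(temp_dir, use_flashpack=True)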

View File

@@ -34,6 +34,7 @@ from diffusers.utils.import_utils import (
is_accelerate_available,
is_bitsandbytes_available,
is_compel_available,
is_flashpack_available,
is_flax_available,
is_gguf_available,
is_kernels_available,
@@ -737,6 +738,13 @@ def require_accelerate(test_case):
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
def require_flashpack(test_case):
"""
Decorator marking a test that requires flashpack. These tests are skipped when flashpack isn't installed.
"""
return pytest.mark.skipif(not is_flashpack_available(), reason="test requires flashpack")(test_case)
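Usage follows the other require_* markers in this file and stacks with them, as the new FlashPack tests do; the test name below is illustrative:

@require_torch_gpu
@require_flashpack
def test_flashpack_roundtrip_on_gpu(self):  # hypothetical test name
    ...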
def require_peft_version_greater(peft_version):
"""
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific