mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-16 20:57:10 +08:00
Compare commits
8 Commits
unet-model
...
bnb-test-f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97ee35f826 | ||
|
|
80ad468e2b | ||
|
|
fe38d77603 | ||
|
|
526498d219 | ||
|
|
6a339ce637 | ||
|
|
26bb7fa0cb | ||
|
|
5063aa5566 | ||
|
|
62b1071609 |
@@ -1533,9 +1533,9 @@ def main(args):
|
||||
# from the cat above, but collate_fn also doubles the prompts list. Use half the
|
||||
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
|
||||
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
|
||||
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
|
||||
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
|
||||
if prompt_embeds_mask is not None:
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
|
||||
prompt_embeds_mask = prompt_embeds_mask.repeat_interleave(num_repeat_elements, dim=0)
|
||||
# Convert images to latent space
|
||||
if args.cache_latents:
|
||||
model_input = latents_cache[step].sample()
|
||||
@@ -1602,10 +1602,11 @@ def main(args):
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
weighting, weighting_prior = torch.chunk(weighting, 2, dim=0)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = torch.mean(
|
||||
(weighting.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
(weighting_prior.float() * (model_pred_prior.float() - target_prior.float()) ** 2).reshape(
|
||||
target_prior.shape[0], -1
|
||||
),
|
||||
1,
|
||||
|
||||
2
setup.py
2
setup.py
@@ -146,6 +146,7 @@ _deps = [
|
||||
"phonemizer",
|
||||
"opencv-python",
|
||||
"timm",
|
||||
"flashpack",
|
||||
]
|
||||
|
||||
# this is a lookup table with items like:
|
||||
@@ -250,6 +251,7 @@ extras["gguf"] = deps_list("gguf", "accelerate")
|
||||
extras["optimum_quanto"] = deps_list("optimum_quanto", "accelerate")
|
||||
extras["torchao"] = deps_list("torchao", "accelerate")
|
||||
extras["nvidia_modelopt"] = deps_list("nvidia_modelopt[hf]")
|
||||
extras["flashpack"] = deps_list("flashpack")
|
||||
|
||||
if os.name == "nt": # windows
|
||||
extras["flax"] = [] # jax is not supported on windows
|
||||
|
||||
@@ -53,4 +53,5 @@ deps = {
|
||||
"phonemizer": "phonemizer",
|
||||
"opencv-python": "opencv-python",
|
||||
"timm": "timm",
|
||||
"flashpack": "flashpack",
|
||||
}
|
||||
|
||||
@@ -540,7 +540,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_version(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
@@ -42,6 +42,7 @@ from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
|
||||
from ..quantizers.quantization_config import QuantizationMethod
|
||||
from ..utils import (
|
||||
CONFIG_NAME,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
@@ -55,6 +56,7 @@ from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
is_flashpack_available,
|
||||
is_peft_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
@@ -673,6 +675,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str = "10GB",
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -725,7 +728,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
" the logger on the traceback to understand the reason why the quantized model is not serializable."
|
||||
)
|
||||
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
|
||||
weights_name = WEIGHTS_NAME
|
||||
if use_flashpack:
|
||||
weights_name = FLASHPACK_WEIGHTS_NAME
|
||||
elif safe_serialization:
|
||||
weights_name = SAFETENSORS_WEIGHTS_NAME
|
||||
|
||||
weights_name = _add_variant(weights_name, variant)
|
||||
weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
@@ -752,58 +760,74 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Save the model
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
if use_flashpack:
|
||||
if is_flashpack_available():
|
||||
import flashpack
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
logger.error(
|
||||
"Saving a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and flashpack to save a FlashPack checkpoint in PyTorch.")
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
flashpack.serialization.pack_to_file(
|
||||
state_dict_or_model=state_dict,
|
||||
destination_path=os.path.join(save_directory, weights_name),
|
||||
target_dtype=self.dtype,
|
||||
)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
# Save the model
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
|
||||
)
|
||||
|
||||
# Clean the folder from a previous save
|
||||
if is_main_process:
|
||||
for filename in os.listdir(save_directory):
|
||||
if filename in state_dict_split.filename_to_tensors.keys():
|
||||
continue
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
if not os.path.isfile(full_filename):
|
||||
continue
|
||||
weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
|
||||
weights_without_ext = weights_without_ext.replace("{suffix}", "")
|
||||
filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
if (
|
||||
filename.startswith(weights_without_ext)
|
||||
and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
for filename, tensors in state_dict_split.filename_to_tensors.items():
|
||||
shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
|
||||
filepath = os.path.join(save_directory, filename)
|
||||
if safe_serialization:
|
||||
# At some point we will need to deal better with save_function (used for TPU and other distributed
|
||||
# joyfulness), but for now this enough.
|
||||
safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
|
||||
else:
|
||||
torch.save(shard, filepath)
|
||||
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
logger.info(
|
||||
f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
|
||||
f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
|
||||
f"index located at {save_index_file}."
|
||||
)
|
||||
else:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
|
||||
if push_to_hub:
|
||||
# Create a new empty model card and eventually tag it
|
||||
@@ -940,6 +964,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, the model is loaded from `flashpack` weights.
|
||||
flashpack_kwargs(`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Kwargs passed to
|
||||
[`flashpack.deserialization.assign_from_file`](https://github.com/fal-ai/flashpack/blob/f1aa91c5cd9532a3dbf5bcc707ab9b01c274b76c/src/flashpack/deserialization.py#L408-L422)
|
||||
|
||||
|
||||
> [!TIP] > To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in
|
||||
with `hf > auth login`. You can also activate the special >
|
||||
@@ -984,6 +1014,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: dict[str, DDUFEntry] | None = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
parallel_config: ParallelConfig | ContextParallelConfig | None = kwargs.pop("parallel_config", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
flashpack_kwargs = kwargs.pop("flashpack_kwargs", {})
|
||||
|
||||
is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
|
||||
if is_parallel_loading_enabled and not low_cpu_mem_usage:
|
||||
@@ -1212,30 +1244,37 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
subfolder=subfolder or "",
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
elif use_safetensors:
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
else:
|
||||
if use_flashpack:
|
||||
weights_name = FLASHPACK_WEIGHTS_NAME
|
||||
elif use_safetensors:
|
||||
weights_name = _add_variant(SAFETENSORS_WEIGHTS_NAME, variant)
|
||||
else:
|
||||
weights_name = None
|
||||
if weights_name is not None:
|
||||
try:
|
||||
resolved_model_file = _get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=weights_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
commit_hash=commit_hash,
|
||||
dduf_entries=dduf_entries,
|
||||
)
|
||||
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
except IOError as e:
|
||||
logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
|
||||
if not allow_pickle:
|
||||
raise
|
||||
logger.warning(
|
||||
"Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
|
||||
)
|
||||
|
||||
if resolved_model_file is None and not is_sharded:
|
||||
resolved_model_file = _get_model_file(
|
||||
@@ -1275,6 +1314,44 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
with ContextManagers(init_contexts):
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
if use_flashpack:
|
||||
if is_flashpack_available():
|
||||
import flashpack
|
||||
else:
|
||||
logger.error(
|
||||
"Loading a FlashPack checkpoint in PyTorch, requires both PyTorch and flashpack to be installed. Please see "
|
||||
"https://pytorch.org/ and https://github.com/fal-ai/flashpack for installation instructions."
|
||||
)
|
||||
raise ImportError("Please install torch and flashpack to load a FlashPack checkpoint in PyTorch.")
|
||||
|
||||
if device_map is None:
|
||||
logger.warning(
|
||||
"`device_map` has not been provided for FlashPack, model will be on `cpu` - provide `device_map` to fully utilize "
|
||||
"the benefit of FlashPack."
|
||||
)
|
||||
flashpack_device = torch.device("cpu")
|
||||
else:
|
||||
device = device_map[""]
|
||||
if isinstance(device, str) and device in ["auto", "balanced", "balanced_low_0", "sequential"]:
|
||||
raise ValueError(
|
||||
"FlashPack `device_map` should not be one of `auto`, `balanced`, `balanced_low_0`, `sequential`. Use a specific device instead, e.g., `device_map='cuda'` or `device_map='cuda:0'"
|
||||
)
|
||||
flashpack_device = torch.device(device) if not isinstance(device, torch.device) else device
|
||||
|
||||
flashpack.mixin.assign_from_file(
|
||||
model=model,
|
||||
path=resolved_model_file[0],
|
||||
device=flashpack_device,
|
||||
**flashpack_kwargs,
|
||||
)
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
if output_loading_info:
|
||||
logger.warning("`output_loading_info` is not supported with FlashPack.")
|
||||
return model, {}
|
||||
|
||||
return model
|
||||
|
||||
if dtype_orig is not None:
|
||||
torch.set_default_dtype(dtype_orig)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class ErnieImageTransformer2DModelOutput(BaseOutput):
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = torch.einsum("...n,d->...nd", pos, omega)
|
||||
return out.float()
|
||||
@@ -400,8 +400,8 @@ class ErnieImageTransformer2DModel(ModelMixin, ConfigMixin):
|
||||
]
|
||||
|
||||
# AdaLN
|
||||
sample = self.time_proj(timestep.to(dtype))
|
||||
sample = sample.to(self.time_embedding.linear_1.weight.dtype)
|
||||
sample = self.time_proj(timestep)
|
||||
sample = sample.to(dtype=dtype)
|
||||
c = self.time_embedding(sample)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = [
|
||||
t.unsqueeze(0).expand(S, -1, -1).contiguous() for t in self.adaLN_modulation(c).chunk(6, dim=-1)
|
||||
|
||||
@@ -445,10 +445,14 @@ class WanAnimateFaceBlockAttnProcessor:
|
||||
# B --> batch_size, T --> reduced inference segment len, N --> face_encoder_num_heads + 1, C --> attn.dim
|
||||
B, T, N, C = encoder_hidden_states.shape
|
||||
|
||||
# Flatten T and N so the K/V projections see a 3D tensor; BnB int8 matmul only
|
||||
# accepts 2D/3D inputs and would otherwise fail on this 4D activation.
|
||||
encoder_hidden_states = encoder_hidden_states.flatten(1, 2) # [B, T, N, C] --> [B, T * N, C]
|
||||
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)) # [B, S, H * D] --> [B, S, H, D]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T, N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
key = key.view(B, T, N, attn.heads, -1) # [B, T * N, H * D_kv] --> [B, T, N, H, D_kv]
|
||||
value = value.view(B, T, N, attn.heads, -1)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
|
||||
@@ -331,7 +331,7 @@ class WanVACETransformer3DModel(
|
||||
)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
|
||||
else:
|
||||
# Prepare VACE hints
|
||||
control_hidden_states_list = []
|
||||
@@ -346,7 +346,7 @@ class WanVACETransformer3DModel(
|
||||
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
||||
if i in self.config.vace_layers:
|
||||
control_hint, scale = control_hidden_states_list.pop()
|
||||
hidden_states = hidden_states + control_hint * scale
|
||||
hidden_states = hidden_states + control_hint.to(hidden_states.device) * scale
|
||||
|
||||
# 6. Output norm, projection & unpatchify
|
||||
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
|
||||
|
||||
@@ -877,10 +877,7 @@ class FluxPipeline(
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
timestep_device = device
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
|
||||
@@ -28,6 +28,7 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from ..utils import (
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
@@ -194,6 +195,7 @@ def filter_model_files(filenames):
|
||||
FLAX_WEIGHTS_NAME,
|
||||
ONNX_WEIGHTS_NAME,
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
]
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -413,6 +415,9 @@ def get_class_obj_and_candidates(
|
||||
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
||||
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
|
||||
|
||||
if class_name.startswith("FlashPack"):
|
||||
class_name = class_name.removeprefix("FlashPack")
|
||||
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
|
||||
@@ -760,6 +765,7 @@ def load_sub_model(
|
||||
provider_options: Any,
|
||||
disable_mmap: bool,
|
||||
quantization_config: Any | None = None,
|
||||
use_flashpack: bool = False,
|
||||
):
|
||||
"""Helper method to load the module `name` from `library_name` and `class_name`"""
|
||||
from ..quantizers import PipelineQuantizationConfig
|
||||
@@ -838,6 +844,9 @@ def load_sub_model(
|
||||
loading_kwargs["variant"] = model_variants.pop(name, None)
|
||||
loading_kwargs["use_safetensors"] = use_safetensors
|
||||
|
||||
if is_diffusers_model:
|
||||
loading_kwargs["use_flashpack"] = use_flashpack
|
||||
|
||||
if from_flax:
|
||||
loading_kwargs["from_flax"] = True
|
||||
|
||||
@@ -887,7 +896,7 @@ def load_sub_model(
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
|
||||
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
|
||||
if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict) and not use_flashpack:
|
||||
# remove hooks
|
||||
remove_hook_from_module(loaded_sub_model, recurse=True)
|
||||
needs_offloading_to_cpu = device_map[""] == "cpu"
|
||||
@@ -1093,6 +1102,7 @@ def _get_ignore_patterns(
|
||||
allow_pickle: bool,
|
||||
use_onnx: bool,
|
||||
is_onnx: bool,
|
||||
use_flashpack: bool,
|
||||
variant: str | None = None,
|
||||
) -> list[str]:
|
||||
if (
|
||||
@@ -1118,6 +1128,9 @@ def _get_ignore_patterns(
|
||||
if not use_onnx:
|
||||
ignore_patterns += ["*.onnx", "*.pb"]
|
||||
|
||||
elif use_flashpack:
|
||||
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb", "*.msgpack"]
|
||||
|
||||
else:
|
||||
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
||||
|
||||
|
||||
@@ -244,6 +244,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
use_flashpack: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -341,6 +342,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
save_method_accept_flashpack = "use_flashpack" in save_method_signature.parameters
|
||||
save_method_accept_peft_format = "save_peft_format" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
@@ -351,6 +353,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
# max_shard_size is expected to not be None in ModelMixin
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
if save_method_accept_flashpack:
|
||||
save_kwargs["use_flashpack"] = use_flashpack
|
||||
if save_method_accept_peft_format:
|
||||
# Set save_peft_format=False for transformers>=5.0.0 compatibility
|
||||
# In transformers 5.0.0+, the default save_peft_format=True adds "base_model.model" prefix
|
||||
@@ -781,6 +785,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
|
||||
@@ -1071,6 +1076,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
provider_options=provider_options,
|
||||
disable_mmap=disable_mmap,
|
||||
quantization_config=quantization_config,
|
||||
use_flashpack=use_flashpack,
|
||||
)
|
||||
logger.info(
|
||||
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
||||
@@ -1576,6 +1582,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
Whether or not to allow for custom pipelines and components defined on the Hub in their own files. This
|
||||
option should only be set to `True` for repositories you trust and in which you have read the code, as
|
||||
it will execute code present on the Hub on your local machine.
|
||||
use_flashpack (`bool`, *optional*, defaults to `False`):
|
||||
If set to `True`, FlashPack weights will always be downloaded if present. If set to `False`, FlashPack
|
||||
weights will never be downloaded.
|
||||
|
||||
Returns:
|
||||
`os.PathLike`:
|
||||
@@ -1600,6 +1609,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
dduf_file: dict[str, DDUFEntry] | None = kwargs.pop("dduf_file", None)
|
||||
use_flashpack = kwargs.pop("use_flashpack", False)
|
||||
|
||||
if dduf_file:
|
||||
if custom_pipeline:
|
||||
@@ -1719,6 +1729,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
allow_pickle,
|
||||
use_onnx,
|
||||
pipeline_class._is_onnx,
|
||||
use_flashpack,
|
||||
variant,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ from .constants import (
|
||||
DEPRECATED_REVISION_ARGS,
|
||||
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
||||
DIFFUSERS_LOAD_ID_FIELDS,
|
||||
FLASHPACK_FILE_EXTENSION,
|
||||
FLASHPACK_WEIGHTS_NAME,
|
||||
FLAX_WEIGHTS_NAME,
|
||||
GGUF_FILE_EXTENSION,
|
||||
HF_ENABLE_PARALLEL_LOADING,
|
||||
@@ -76,6 +78,7 @@ from .import_utils import (
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
is_flashpack_available,
|
||||
is_flax_available,
|
||||
is_ftfy_available,
|
||||
is_gguf_available,
|
||||
|
||||
@@ -34,6 +34,8 @@ ONNX_WEIGHTS_NAME = "model.onnx"
|
||||
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
|
||||
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
||||
FLASHPACK_WEIGHTS_NAME = "model.flashpack"
|
||||
FLASHPACK_FILE_EXTENSION = "flashpack"
|
||||
GGUF_FILE_EXTENSION = "gguf"
|
||||
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
||||
|
||||
@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter", get_dist_name=True)
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_flashpack_available, _flashpack_version = _is_package_available("flashpack")
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
|
||||
|
||||
@@ -361,6 +362,10 @@ def is_gguf_available():
|
||||
return _gguf_available
|
||||
|
||||
|
||||
def is_flashpack_available():
|
||||
return _flashpack_available
|
||||
|
||||
|
||||
def is_torchao_available():
|
||||
return _torchao_available
|
||||
|
||||
|
||||
@@ -465,8 +465,7 @@ class UNetTesterMixin:
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
@@ -481,9 +480,9 @@ class UNetTesterMixin:
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
|
||||
class ModelTesterMixin:
|
||||
|
||||
@@ -205,6 +205,11 @@ class BaseModelTesterConfig:
|
||||
"""Additional kwargs to pass to from_pretrained (e.g., subfolder, variant)."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def torch_dtype(self) -> torch.dtype:
|
||||
"""Compute dtype used to build dummy inputs and cast inputs where needed."""
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def output_shape(self) -> Optional[tuple]:
|
||||
"""Expected output shape for output validation tests."""
|
||||
@@ -287,9 +292,8 @@ class ModelTesterMixin:
|
||||
f"Parameter shape mismatch for {param_name}. Original: {param_1.shape}, loaded: {param_2.shape}"
|
||||
)
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -309,9 +313,8 @@ class ModelTesterMixin:
|
||||
|
||||
new_model.to(torch_device)
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
image = model(**inputs_dict, return_dict=False)[0]
|
||||
new_image = new_model(**inputs_dict, return_dict=False)[0]
|
||||
image = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
new_image = new_model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(image, new_image, atol=atol, rtol=rtol, msg="Models give different forward passes.")
|
||||
|
||||
@@ -339,9 +342,8 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
first = model(**inputs_dict, return_dict=False)[0]
|
||||
second = model(**inputs_dict, return_dict=False)[0]
|
||||
first = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
second = model(**self.get_dummy_inputs(), return_dict=False)[0]
|
||||
|
||||
first_flat = first.flatten()
|
||||
second_flat = second.flatten()
|
||||
@@ -398,9 +400,8 @@ class ModelTesterMixin:
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
outputs_dict = model(**inputs_dict)
|
||||
outputs_tuple = model(**inputs_dict, return_dict=False)
|
||||
outputs_dict = model(**self.get_dummy_inputs())
|
||||
outputs_tuple = model(**self.get_dummy_inputs(), return_dict=False)
|
||||
|
||||
recursive_check(outputs_tuple, outputs_dict)
|
||||
|
||||
@@ -527,10 +528,8 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after sharded save/load"
|
||||
@@ -569,10 +568,8 @@ class ModelTesterMixin:
|
||||
new_model = new_model.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_new = self.get_dummy_inputs()
|
||||
new_output = new_model(**inputs_dict_new, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, new_output, atol=atol, rtol=rtol, msg="Output should match after variant sharded save/load"
|
||||
@@ -622,10 +619,8 @@ class ModelTesterMixin:
|
||||
model_parallel = model_parallel.to(torch_device)
|
||||
|
||||
torch.manual_seed(0)
|
||||
# Re-create inputs only if they contain a generator (which needs to be reset)
|
||||
if "generator" in inputs_dict:
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict, return_dict=False)[0]
|
||||
inputs_dict_parallel = self.get_dummy_inputs()
|
||||
output_parallel = model_parallel(**inputs_dict_parallel, return_dict=False)[0]
|
||||
|
||||
assert_tensors_close(
|
||||
base_output, output_parallel, atol=atol, rtol=rtol, msg="Output should match with parallel loading"
|
||||
|
||||
@@ -92,6 +92,9 @@ class TorchCompileTesterMixin:
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
|
||||
@@ -359,15 +359,7 @@ class QuantizationTesterMixin:
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
assert not self._is_module_quantized(module), f"Module {name} is still quantized after dequantize()"
|
||||
|
||||
# Get model dtype from first parameter
|
||||
model_dtype = next(model.parameters()).dtype
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
# Cast inputs to model dtype
|
||||
inputs = {
|
||||
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
assert output is not None, "Model output is None after dequantization"
|
||||
assert not torch.isnan(output).any(), "Model output contains NaN after dequantization"
|
||||
@@ -575,33 +567,28 @@ class BitsAndBytesTesterMixin(BitsAndBytesConfigMixin, QuantizationTesterMixin):
|
||||
|
||||
@torch.no_grad()
|
||||
def test_bnb_keep_modules_in_fp32(self):
|
||||
if not hasattr(self.model_class, "_keep_in_fp32_modules"):
|
||||
pytest.skip(f"{self.model_class.__name__} does not have _keep_in_fp32_modules")
|
||||
fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
if not fp32_modules:
|
||||
pytest.skip(f"{self.model_class.__name__} does not declare _keep_in_fp32_modules")
|
||||
|
||||
config_kwargs = BitsAndBytesConfigMixin.BNB_CONFIGS["4bit_nf4"]
|
||||
|
||||
original_fp32_modules = getattr(self.model_class, "_keep_in_fp32_modules", None)
|
||||
self.model_class._keep_in_fp32_modules = ["proj_out"]
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
model.to(torch_device)
|
||||
|
||||
try:
|
||||
model = self._create_quantized_model(config_kwargs)
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, (
|
||||
f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
)
|
||||
else:
|
||||
assert module.weight.dtype == torch.uint8, (
|
||||
f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
)
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if any(fp32_name in name for fp32_name in model._keep_in_fp32_modules):
|
||||
assert module.weight.dtype == torch.float32, (
|
||||
f"Module {name} should be FP32 but is {module.weight.dtype}"
|
||||
)
|
||||
else:
|
||||
assert module.weight.dtype == torch.uint8, (
|
||||
f"Module {name} should be uint8 but is {module.weight.dtype}"
|
||||
)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
finally:
|
||||
if original_fp32_modules is not None:
|
||||
self.model_class._keep_in_fp32_modules = original_fp32_modules
|
||||
inputs = self.get_dummy_inputs()
|
||||
_ = model(**inputs)
|
||||
|
||||
def test_bnb_modules_to_not_convert(self):
|
||||
"""Test that modules_to_not_convert parameter works correctly."""
|
||||
|
||||
@@ -159,21 +159,36 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
(batch_size, height * width, num_latent_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
(batch_size, sequence_length, embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"pooled_projections": randn_tensor(
|
||||
(batch_size, embedding_dim), generator=self.generator, device=torch_device
|
||||
(batch_size, embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
(height * width, num_image_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
(sequence_length, num_image_channels),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
@@ -320,6 +335,10 @@ class TestFluxSingleFile(FluxTransformerTesterConfig, SingleFileTesterMixin):
|
||||
class TestFluxTransformerBitsAndBytes(FluxTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux Transformer."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.float16
|
||||
|
||||
|
||||
class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
|
||||
"""Quanto quantization tests for Flux Transformer."""
|
||||
|
||||
@@ -91,11 +91,13 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
@@ -113,27 +113,32 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(batch_size, clip_seq_len, clip_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(batch_size, 3, inference_segment_length, face_height, face_width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -96,16 +96,19 @@ class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(batch_size, vace_in_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
dtype=self.torch_dtype,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
@@ -24,39 +26,64 @@ from ...testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
_LAYERWISE_CASTING_XFAIL_REASON = (
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
)
|
||||
|
||||
|
||||
class UNet1DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (standard variant)."""
|
||||
class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (14, 16)
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
super().test_output()
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (8, 8, 16, 16),
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
@@ -70,40 +97,18 @@ class UNet1DTesterConfig(BaseModelTesterConfig):
|
||||
"up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"),
|
||||
"act_fn": "swish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1D(UNet1DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestUNet1DMemory(UNet1DTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet"
|
||||
)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -126,7 +131,12 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348])
|
||||
# fmt: on
|
||||
assert torch.allclose(output_slice, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_unet_1d_maestro(self):
|
||||
@@ -147,29 +157,98 @@ class TestUNet1DHubLoading(UNet1DTesterConfig):
|
||||
assert (output_sum - 224.0896).abs() < 0.5
|
||||
assert (output_max - 0.0607).abs() < 4e-4
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
|
||||
# =============================================================================
|
||||
# UNet1D RL (Value Function) Model Tests
|
||||
# =============================================================================
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
|
||||
class UNet1DRLTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet1DModel testing (RL value function variant)."""
|
||||
class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet1DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet1DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device)
|
||||
time_step = torch.tensor([10] * batch_size).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 14, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1,)
|
||||
return (4, 14, 1)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
def test_determinism(self):
|
||||
super().test_determinism()
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def test_outputs_equivalence(self):
|
||||
super().test_outputs_equivalence()
|
||||
|
||||
def test_from_save_pretrained(self):
|
||||
super().test_from_save_pretrained()
|
||||
|
||||
def test_from_save_pretrained_variant(self):
|
||||
super().test_from_save_pretrained_variant()
|
||||
|
||||
def test_model_from_pretrained(self):
|
||||
super().test_model_from_pretrained()
|
||||
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function is different output shape
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_ema_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 14,
|
||||
"out_channels": 14,
|
||||
"down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"],
|
||||
@@ -185,54 +264,18 @@ class UNet1DRLTesterConfig(BaseModelTesterConfig):
|
||||
"time_embedding_type": "positional",
|
||||
"act_fn": "mish",
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_features = 14
|
||||
seq_len = 16
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_features, seq_len)).to(torch_device),
|
||||
"timestep": torch.tensor([10] * batch_size).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet1DRL(UNet1DRLTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Not implemented yet for this UNet")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@torch.no_grad()
|
||||
def test_output(self):
|
||||
# UNetRL is a value-function with different output shape (batch, 1)
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
output = model(**inputs_dict, return_dict=False)[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1))
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet1DRLMemory(UNet1DRLTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.xfail(reason=_LAYERWISE_CASTING_XFAIL_REASON)
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
value_function, vf_loading_info = UNet1DModel.from_pretrained(
|
||||
"bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function"
|
||||
)
|
||||
assert value_function is not None
|
||||
assert len(vf_loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(value_function)
|
||||
self.assertEqual(len(vf_loading_info["missing_keys"]), 0)
|
||||
|
||||
value_function.to(torch_device)
|
||||
image = value_function(**self.get_dummy_inputs())
|
||||
image = value_function(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -256,4 +299,31 @@ class TestUNet1DRLHubLoading(UNet1DRLTesterConfig):
|
||||
# fmt: off
|
||||
expected_output_slice = torch.tensor([165.25] * seq_len)
|
||||
# fmt: on
|
||||
assert torch.allclose(output, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# Not implemented yet for this UNet
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason=(
|
||||
"RuntimeError: 'fill_out' not implemented for 'Float8_e4m3fn'. The error is caused due to certain torch.float8_e4m3fn and torch.float8_e5m2 operations "
|
||||
"not being supported when using deterministic algorithms (which is what the tests run with). To fix:\n"
|
||||
"1. Wait for next PyTorch release: https://github.com/pytorch/pytorch/issues/137160.\n"
|
||||
"2. Unskip this test."
|
||||
),
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
|
||||
import gc
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
@@ -30,40 +31,39 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Standard UNet2D Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for standard UNet2DModel testing."""
|
||||
class Unet2DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 2,
|
||||
"down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
|
||||
@@ -74,22 +74,11 @@ class UNet2DTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_mid_block_attn_groups(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["add_attention"] = True
|
||||
init_dict["attn_norm_num_groups"] = 4
|
||||
@@ -98,11 +87,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
assert model.mid_block.attentions[0].group_norm is not None, (
|
||||
"Mid block Attention group norm should exist but does not."
|
||||
self.assertIsNotNone(
|
||||
model.mid_block.attentions[0].group_norm, "Mid block Attention group norm should exist but does not."
|
||||
)
|
||||
assert model.mid_block.attentions[0].group_norm.num_groups == init_dict["attn_norm_num_groups"], (
|
||||
"Mid block Attention group norm does not have the expected number of groups."
|
||||
self.assertEqual(
|
||||
model.mid_block.attentions[0].group_norm.num_groups,
|
||||
init_dict["attn_norm_num_groups"],
|
||||
"Mid block Attention group norm does not have the expected number of groups.",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -111,15 +102,13 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_mid_block_none(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
mid_none_init_dict = self.get_init_dict()
|
||||
mid_none_inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
mid_none_init_dict["mid_block_type"] = None
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -130,7 +119,7 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
mid_none_model.to(torch_device)
|
||||
mid_none_model.eval()
|
||||
|
||||
assert mid_none_model.mid_block is None, "Mid block should not exist."
|
||||
self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
@@ -144,10 +133,8 @@ class TestUNet2D(UNet2DTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
if isinstance(mid_none_output, dict):
|
||||
mid_none_output = mid_none_output.to_tuple()[0]
|
||||
|
||||
assert not torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different."
|
||||
self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
|
||||
|
||||
|
||||
class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"AttnUpBlock2D",
|
||||
@@ -156,32 +143,41 @@ class TestUNet2DTraining(UNet2DTesterConfig, TrainingTesterMixin):
|
||||
"UpBlock2D",
|
||||
"DownBlock2D",
|
||||
}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
attention_head_dim = 8
|
||||
block_out_channels = (16, 32)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# UNet2D LDM Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UNet2DLDMTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel LDM variant testing."""
|
||||
class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 32,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
@@ -191,34 +187,17 @@ class UNet2DLDMTesterConfig(BaseModelTesterConfig):
|
||||
"down_block_types": ("DownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "UpBlock2D"),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestUNet2DLDMTraining(UNet2DLDMTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -226,7 +205,7 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
def test_from_pretrained_accelerate(self):
|
||||
model, _ = UNet2DModel.from_pretrained("fusing/unet-ldm-dummy-update", output_loading_info=True)
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs()).sample
|
||||
image = model(**self.dummy_input).sample
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -286,31 +265,44 @@ class TestUNet2DLDMHubLoading(UNet2DLDMTesterConfig):
|
||||
expected_output_slice = torch.tensor([-13.3258, -20.1100, -15.9873, -17.6617, -23.0596, -17.9419, -13.3675, -16.1889, -12.3800])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-3)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
|
||||
|
||||
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
|
||||
attention_head_dim = 32
|
||||
block_out_channels = (32, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NCSN++ Model Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class NCSNppTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DModel NCSN++ variant testing."""
|
||||
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DModel
|
||||
def dummy_input(self, sizes=(32, 32)):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": [32, 64, 64, 64],
|
||||
"in_channels": 3,
|
||||
"layers_per_block": 1,
|
||||
@@ -332,71 +324,17 @@ class NCSNppTesterConfig(BaseModelTesterConfig):
|
||||
"SkipUpBlock2D",
|
||||
],
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor(batch_size * [10]).to(dtype=torch.int32, device=torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestNCSNpp(NCSNppTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_keep_in_fp32_modules(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_from_save_pretrained_dtype_inference(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppMemory(NCSNppTesterConfig, MemoryTesterMixin):
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. "
|
||||
"Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestNCSNppTraining(NCSNppTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
@slow
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256", output_loading_info=True)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
inputs = self.get_dummy_inputs()
|
||||
inputs = self.dummy_input
|
||||
noise = floats_tensor((4, 3) + (256, 256)).to(torch_device)
|
||||
inputs["sample"] = noise
|
||||
image = model(**inputs)
|
||||
@@ -423,7 +361,7 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
expected_output_slice = torch.tensor([-4836.2178, -6487.1470, -3816.8196, -7964.9302, -10966.3037, -20043.5957, 8137.0513, 2340.3328, 544.6056])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
def test_output_pretrained_ve_large(self):
|
||||
model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
|
||||
@@ -444,4 +382,35 @@ class TestNCSNppHubLoading(NCSNppTesterConfig):
|
||||
expected_output_slice = torch.tensor([-0.0325, -0.0900, -0.0869, -0.0332, -0.0725, -0.0270, -0.0101, 0.0227, 0.0256])
|
||||
# fmt: on
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# not required for this model
|
||||
pass
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"UNetMidBlock2D",
|
||||
}
|
||||
|
||||
block_out_channels = (32, 64, 64, 64)
|
||||
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_effective_gradient_checkpointing(self):
|
||||
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip(
|
||||
"To make layerwise casting work with this model, we will have to update the implementation. Due to potentially low usage, we don't support it here."
|
||||
)
|
||||
def test_layerwise_casting_memory(self):
|
||||
pass
|
||||
|
||||
@@ -20,7 +20,6 @@ import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from parameterized import parameterized
|
||||
@@ -53,24 +52,17 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
from ..test_modeling_common import (
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
UNetTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig
|
||||
|
||||
from ..testing_utils.lora import check_if_lora_correctly_set
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -90,6 +82,16 @@ def get_unet_lora_config():
|
||||
return unet_lora_config
|
||||
|
||||
|
||||
def check_if_lora_correctly_set(model) -> bool:
|
||||
"""
|
||||
Checks if the LoRA layers are correctly set with peft
|
||||
"""
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def create_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
@@ -352,28 +354,34 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet2DConditionModel testing."""
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
main_input_name = "sample"
|
||||
# We override the items here because the unet under consideration is small.
|
||||
model_split_percents = [0.5, 0.34, 0.4]
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet2DConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int, int]:
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list[float]:
|
||||
return [0.5, 0.34, 0.4]
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
"""Return UNet2D model initialization arguments."""
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
@@ -385,24 +393,26 @@ class UNet2DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
"""Return dummy inputs for UNet2D model."""
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -417,13 +427,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
@@ -437,13 +446,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["cross_attention_dim"] = (8, 8)
|
||||
|
||||
@@ -457,13 +465,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_simple_projection(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -482,13 +489,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
batch_size, _, _, sample_size = inputs_dict["sample"].shape
|
||||
|
||||
@@ -508,287 +514,12 @@ class TestUNet2DCondition(UNet2DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty small,
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
|
||||
def test_pickle(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
|
||||
class TestUNet2DConditionHubLoading(UNet2DConditionTesterConfig):
|
||||
"""Hub checkpoint loading tests for UNet2DConditionModel."""
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
|
||||
class TestUNet2DConditionLoRA(UNet2DConditionTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for UNet2DConditionModel."""
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated load_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with pytest.warns(FutureWarning, match="Using the `load_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
"""Test that deprecated save_attn_procs method raises FutureWarning."""
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with pytest.warns(FutureWarning, match="Using the `save_attn_procs\\(\\)` method has been deprecated"):
|
||||
model.save_attn_procs(os.path.join(tmpdirname))
|
||||
|
||||
|
||||
class TestUNet2DConditionMemory(UNet2DConditionTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for UNet2DConditionModel."""
|
||||
|
||||
|
||||
class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for UNet2DConditionModel."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for UNet2DConditionModel."""
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -813,7 +544,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
assert output is not None
|
||||
|
||||
def test_model_sliceable_head_dim(self):
|
||||
init_dict = self.get_init_dict()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -831,6 +562,21 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
for module in model.children():
|
||||
check_sliceable_dim_attr(module)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"UpBlock2D",
|
||||
"Transformer2DModel",
|
||||
"DownBlock2D",
|
||||
}
|
||||
attention_head_dim = (8, 16)
|
||||
block_out_channels = (16, 32)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
|
||||
)
|
||||
|
||||
def test_special_attn_proc(self):
|
||||
class AttnEasyProc(torch.nn.Module):
|
||||
def __init__(self, num):
|
||||
@@ -872,8 +618,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
return hidden_states
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -900,8 +645,7 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
]
|
||||
)
|
||||
def test_model_xattn_mask(self, mask_dtype):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
|
||||
model.to(torch_device)
|
||||
@@ -931,13 +675,39 @@ class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterM
|
||||
"masking the last token from our cond should be equivalent to truncating that token out of the condition"
|
||||
)
|
||||
|
||||
# see diffusers.models.attention_processor::Attention#prepare_attention_mask
|
||||
# note: we may not need to fix mask padding to work for stable-diffusion cross-attn masks.
|
||||
# since the use-case (somebody passes in a too-short cross-attn mask) is pretty esoteric.
|
||||
# maybe it's fine that this only works for the unclip use-case.
|
||||
@mark.skip(
|
||||
reason="we currently pad mask by target_length tokens (what unclip needs), whereas stable-diffusion's cross-attn needs to instead pad by remaining_length."
|
||||
)
|
||||
def test_model_xattn_padding(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
"""Custom Diffusion processor tests for UNet2DConditionModel."""
|
||||
model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
cond = inputs_dict["encoder_hidden_states"]
|
||||
with torch.no_grad():
|
||||
full_cond_out = model(**inputs_dict).sample
|
||||
assert full_cond_out is not None
|
||||
|
||||
batch, tokens, _ = cond.shape
|
||||
keeplast_mask = (torch.arange(tokens) == tokens - 1).expand(batch, -1).to(cond.device, torch.bool)
|
||||
keeplast_out = model(**{**inputs_dict, "encoder_attention_mask": keeplast_mask}).sample
|
||||
assert not keeplast_out.allclose(full_cond_out), "a 'keep last token' mask should change the result"
|
||||
|
||||
trunc_mask = torch.zeros(batch, tokens - 1, device=cond.device, dtype=torch.bool)
|
||||
trunc_mask_out = model(**{**inputs_dict, "encoder_attention_mask": trunc_mask}).sample
|
||||
assert trunc_mask_out.allclose(keeplast_out), (
|
||||
"a mask with fewer tokens than condition, will be padded with 'keep' tokens. a 'discard-all' mask missing the final token is thus equivalent to a 'keep last' mask."
|
||||
)
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -963,8 +733,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
assert (sample1 - sample2).abs().max() < 3e-3
|
||||
|
||||
def test_custom_diffusion_save_load(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -984,7 +754,7 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname, safe_serialization=False)
|
||||
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
@@ -1003,8 +773,8 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_custom_diffusion_xformers_on_off(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1028,28 +798,41 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_pickle(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterMixin):
|
||||
"""IP Adapter tests for UNet2DConditionModel."""
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
def create_ip_adapter_state_dict(self, model):
|
||||
return create_ip_adapter_state_dict(model)
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
batch_size = inputs_dict["encoder_hidden_states"].shape[0]
|
||||
# for ip-adapter image_embeds has shape [batch_size, num_image, embed_dim]
|
||||
cross_attention_dim = getattr(model.config, "cross_attention_dim", 8)
|
||||
image_embeds = floats_tensor((batch_size, 1, cross_attention_dim)).to(torch_device)
|
||||
inputs_dict["added_cond_kwargs"] = {"image_embeds": [image_embeds]}
|
||||
return inputs_dict
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
def test_asymmetrical_unet(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
# Add asymmetry to configs
|
||||
init_dict["transformer_layers_per_block"] = [[3, 2], 1]
|
||||
init_dict["reverse_transformer_layers_per_block"] = [[3, 4], 1]
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
output = model(**inputs_dict).sample
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
|
||||
# Check if input and output shapes are the same
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_ip_adapter(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1122,8 +905,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_ip_adapter_plus(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
@@ -1195,16 +977,185 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
|
||||
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
|
||||
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for UNet2DConditionModel."""
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, subfolder="unet", local_files_only=True)
|
||||
loaded_model = loaded_model.to(torch_device)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
@parameterized.expand(
|
||||
[
|
||||
("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
|
||||
("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
|
||||
]
|
||||
)
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_subfolder(self, repo_id, variant):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
loaded_model = self.model_class.from_pretrained(repo_id, variant=variant, subfolder="unet", device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
|
||||
loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_torch_accelerator
|
||||
def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
|
||||
_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy-subfolder")
|
||||
loaded_model = self.model_class.from_pretrained(
|
||||
ckpt_path, local_files_only=True, subfolder="unet", device_map="auto"
|
||||
)
|
||||
new_output = loaded_model(**inputs_dict)
|
||||
|
||||
assert loaded_model
|
||||
assert new_output.sample.shape == (4, 4, 16, 16)
|
||||
|
||||
@require_peft_backend
|
||||
def test_load_attn_procs_raise_warning(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
# forward pass without LoRA
|
||||
with torch.no_grad():
|
||||
non_lora_sample = model(**inputs_dict).sample
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
# forward pass with LoRA
|
||||
with torch.no_grad():
|
||||
lora_sample_1 = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
model.unload_lora()
|
||||
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `load_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
# import to still check for the rest of the stuff.
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with torch.no_grad():
|
||||
lora_sample_2 = model(**inputs_dict).sample
|
||||
|
||||
assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), (
|
||||
"LoRA injected UNet should produce different results."
|
||||
)
|
||||
assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), (
|
||||
"Loading from a saved checkpoint should produce identical results."
|
||||
)
|
||||
|
||||
@require_peft_backend
|
||||
def test_save_attn_procs_raise_warning(self):
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
unet_lora_config = get_unet_lora_config()
|
||||
model.add_adapter(unet_lora_config)
|
||||
|
||||
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with self.assertWarns(FutureWarning) as warning:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
|
||||
warning_message = str(warning.warnings[0].message)
|
||||
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
|
||||
|
||||
|
||||
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for UNet2DConditionModel."""
|
||||
class UNet2DConditionModelCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class UNet2DConditionModelLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return UNet2DConditionModelTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
@slow
|
||||
|
||||
@@ -18,44 +18,47 @@ import unittest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import UNet3DConditionModel
|
||||
from diffusers.models import ModelMixin, UNet3DConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
)
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, skip_mps, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNet3DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNet3DConditionModel testing."""
|
||||
class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNet3DConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNet3DConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (4, 8),
|
||||
"norm_num_groups": 4,
|
||||
"down_block_types": (
|
||||
@@ -70,25 +73,27 @@ class UNet3DConditionTesterConfig(BaseModelTesterConfig):
|
||||
"layers_per_block": 1,
|
||||
"sample_size": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
num_frames = 4
|
||||
sizes = (16, 16)
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
}
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
# Overriding to set `norm_num_groups` needs to be different for this model.
|
||||
def test_forward_with_norm_groups(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
@@ -102,74 +107,39 @@ class TestUNet3DCondition(UNet3DConditionTesterConfig, ModelTesterMixin, UNetTes
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
# Overriding since the UNet3D outputs a different structure.
|
||||
@torch.no_grad()
|
||||
def test_determinism(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
with torch.no_grad():
|
||||
# Warmup pass when using mps (see #372)
|
||||
if torch_device == "mps" and isinstance(model, ModelMixin):
|
||||
model(**self.dummy_input)
|
||||
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
first = model(**inputs_dict)
|
||||
if isinstance(first, dict):
|
||||
first = first.sample
|
||||
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
second = model(**inputs_dict)
|
||||
if isinstance(second, dict):
|
||||
second = second.sample
|
||||
|
||||
out_1 = first.cpu().numpy()
|
||||
out_2 = second.cpu().numpy()
|
||||
out_1 = out_1[~np.isnan(out_1)]
|
||||
out_2 = out_2[~np.isnan(out_2)]
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
assert max_diff <= 1e-5
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
assert output.shape == output_2.shape, "Shape doesn't match"
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
|
||||
class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
init_dict["attention_head_dim"] = 8
|
||||
@@ -192,3 +162,22 @@ class TestUNet3DConditionAttention(UNet3DConditionTesterConfig, AttentionTesterM
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
def test_feed_forward_chunking(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict["block_out_channels"] = (32, 64)
|
||||
init_dict["norm_num_groups"] = 32
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)[0]
|
||||
|
||||
model.enable_forward_chunking()
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict)[0]
|
||||
|
||||
self.assertEqual(output.shape, output_2.shape, "Shape doesn't match")
|
||||
assert np.abs(output.cpu() - output_2.cpu()).max() < 1e-2
|
||||
|
||||
@@ -13,42 +13,59 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel
|
||||
from diffusers.utils import logging
|
||||
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, is_flaky, torch_device
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetControlNetXSModel testing."""
|
||||
class UNetControlNetXSModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetControlNetXSModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNetControlNetXSModel
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32) # size of additional, unprocessed image for control-conditioning
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
|
||||
controlnet_cond = floats_tensor((batch_size, *conditioning_image_size)).to(torch_device)
|
||||
conditioning_scale = 1
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"controlnet_cond": controlnet_cond,
|
||||
"conditioning_scale": conditioning_scale,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"sample_size": 16,
|
||||
"down_block_types": ("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
"up_block_types": ("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
@@ -63,23 +80,11 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
"ctrl_max_norm_num_groups": 2,
|
||||
"ctrl_conditioning_embedding_out_channels": (2, 2),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (16, 16)
|
||||
conditioning_image_size = (3, 32, 32)
|
||||
|
||||
return {
|
||||
"sample": floats_tensor((batch_size, num_channels) + sizes).to(torch_device),
|
||||
"timestep": torch.tensor([10]).to(torch_device),
|
||||
"encoder_hidden_states": floats_tensor((batch_size, 4, 8)).to(torch_device),
|
||||
"controlnet_cond": floats_tensor((batch_size, *conditioning_image_size)).to(torch_device),
|
||||
"conditioning_scale": 1,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_unet(self):
|
||||
"""Build the underlying UNet for tests that construct UNetControlNetXSModel from UNet + Adapter."""
|
||||
"""For some tests we also need the underlying UNet. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
return UNet2DConditionModel(
|
||||
block_out_channels=(4, 8),
|
||||
layers_per_block=2,
|
||||
@@ -94,16 +99,10 @@ class UNetControlNetXSTesterConfig(BaseModelTesterConfig):
|
||||
)
|
||||
|
||||
def get_dummy_controlnet_from_unet(self, unet, **kwargs):
|
||||
"""Build the ControlNetXS-Adapter from a UNet."""
|
||||
"""For some tests we also need the underlying ControlNetXS-Adapter. For these, we'll build the UNetControlNetXSModel from the UNet and ControlNetXS-Adapter"""
|
||||
# size_ratio and conditioning_embedding_out_channels chosen to keep model small
|
||||
return ControlNetXSAdapter.from_unet(unet, size_ratio=1, conditioning_embedding_out_channels=(2, 2), **kwargs)
|
||||
|
||||
|
||||
class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel only supports SD/SDXL with norm_num_groups=32
|
||||
pass
|
||||
|
||||
def test_from_unet(self):
|
||||
unet = self.get_dummy_unet()
|
||||
controlnet = self.get_dummy_controlnet_from_unet(unet)
|
||||
@@ -116,7 +115,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert torch.equal(model_state_dict[weight_dict_prefix + "." + param_name], param_value)
|
||||
|
||||
# # check unet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
"time_embedding",
|
||||
"conv_in",
|
||||
@@ -153,7 +152,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert_equal_weights(u.upsamplers[0], f"up_blocks.{i}.upsamplers")
|
||||
|
||||
# # check controlnet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = {
|
||||
"controlnet_cond_embedding": "controlnet_cond_embedding",
|
||||
"conv_in": "ctrl_conv_in",
|
||||
@@ -194,12 +193,12 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
for p in module.parameters():
|
||||
assert p.requires_grad
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
model = UNetControlNetXSModel(**init_dict)
|
||||
model.freeze_unet_params()
|
||||
|
||||
# # check unet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_unet = [
|
||||
model.base_time_embedding,
|
||||
model.base_conv_in,
|
||||
@@ -237,7 +236,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
assert_frozen(u.upsamplers)
|
||||
|
||||
# # check controlnet
|
||||
# everything except down,mid,up blocks
|
||||
# everything expect down,mid,up blocks
|
||||
modules_from_controlnet = [
|
||||
model.controlnet_cond_embedding,
|
||||
model.ctrl_conv_in,
|
||||
@@ -268,6 +267,16 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
for u in model.up_blocks:
|
||||
assert_unfrozen(u.ctrl_to_base)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@is_flaky
|
||||
def test_forward_no_control(self):
|
||||
unet = self.get_dummy_unet()
|
||||
@@ -278,7 +287,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
unet = unet.to(torch_device)
|
||||
model = model.to(torch_device)
|
||||
|
||||
input_ = self.get_dummy_inputs()
|
||||
input_ = self.dummy_input
|
||||
|
||||
control_specific_input = ["controlnet_cond", "conditioning_scale"]
|
||||
input_for_unet = {k: v for k, v in input_.items() if k not in control_specific_input}
|
||||
@@ -303,7 +312,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
model = model.to(torch_device)
|
||||
model_mix_time = model_mix_time.to(torch_device)
|
||||
|
||||
input_ = self.get_dummy_inputs()
|
||||
input_ = self.dummy_input
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**input_).sample
|
||||
@@ -311,14 +320,7 @@ class TestUNetControlNetXS(UNetControlNetXSTesterConfig, ModelTesterMixin, UNetT
|
||||
|
||||
assert output.shape == output_mix_time.shape
|
||||
|
||||
|
||||
class TestUNetControlNetXSTraining(UNetControlNetXSTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"Transformer2DModel",
|
||||
"UNetMidBlock2DCrossAttn",
|
||||
"ControlNetXSCrossAttnDownBlock2D",
|
||||
"ControlNetXSCrossAttnMidBlock2D",
|
||||
"ControlNetXSCrossAttnUpBlock2D",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
@unittest.skip("Test not supported.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
# UNetControlNetXSModel currently only supports StableDiffusion and StableDiffusion-XL, both of which have norm_num_groups fixed at 32. So we don't need to test different values for norm_num_groups.
|
||||
pass
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import UNetSpatioTemporalConditionModel
|
||||
from diffusers.utils import logging
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from ...testing_utils import (
|
||||
@@ -28,34 +28,45 @@ from ...testing_utils import (
|
||||
skip_mps,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import UNetTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@skip_mps
|
||||
class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"""Base configuration for UNetSpatioTemporalConditionModel testing."""
|
||||
class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
|
||||
model_class = UNetSpatioTemporalConditionModel
|
||||
main_input_name = "sample"
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return UNetSpatioTemporalConditionModel
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 2, 4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def main_input_name(self):
|
||||
return "sample"
|
||||
|
||||
@property
|
||||
def fps(self):
|
||||
return 6
|
||||
@@ -72,8 +83,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
def addition_time_embed_dim(self):
|
||||
return 32
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": (
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
@@ -92,23 +103,8 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
"projection_class_embeddings_input_dim": self.addition_time_embed_dim * 3,
|
||||
"addition_time_embed_dim": self.addition_time_embed_dim,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 2
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_frames, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 1, 32)).to(torch_device)
|
||||
|
||||
return {
|
||||
"sample": noise,
|
||||
"timestep": time_step,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"added_time_ids": self._get_add_time_ids(),
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def _get_add_time_ids(self, do_classifier_free_guidance=True):
|
||||
add_time_ids = [self.fps, self.motion_bucket_id, self.noise_aug_strength]
|
||||
@@ -128,15 +124,43 @@ class UNetSpatioTemporalTesterConfig(BaseModelTesterConfig):
|
||||
|
||||
return add_time_ids
|
||||
|
||||
|
||||
class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, UNetTesterMixin):
|
||||
@pytest.mark.skip("Number of Norm Groups is not configurable")
|
||||
@unittest.skip("Number of Norm Groups is not configurable")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Deprecated functionality")
|
||||
def test_model_attention_slicing(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_use_linear_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_simple_projection(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not supported")
|
||||
def test_model_with_class_embeddings_concat(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
def test_model_with_num_attention_heads_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
model = self.model_class(**init_dict)
|
||||
@@ -149,13 +173,12 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_cross_attention_dim_tuple(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["cross_attention_dim"] = (32, 32)
|
||||
|
||||
@@ -169,13 +192,27 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
assert output is not None
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
num_attention_heads = (8, 16)
|
||||
super().test_gradient_checkpointing_is_applied(
|
||||
expected_set=expected_set, num_attention_heads=num_attention_heads
|
||||
)
|
||||
|
||||
def test_pickle(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["num_attention_heads"] = (8, 16)
|
||||
|
||||
@@ -188,33 +225,3 @@ class TestUNetSpatioTemporal(UNetSpatioTemporalTesterConfig, ModelTesterMixin, U
|
||||
sample_copy = copy.copy(sample)
|
||||
|
||||
assert (sample - sample_copy).abs().max() < 1e-4
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalAttention(UNetSpatioTemporalTesterConfig, AttentionTesterMixin):
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict = self.get_init_dict()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1.processor.__class__.__name__
|
||||
== "XFormersAttnProcessor"
|
||||
), "xformers is not enabled"
|
||||
|
||||
|
||||
class TestUNetSpatioTemporalTraining(UNetSpatioTemporalTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"TransformerSpatioTemporalModel",
|
||||
"CrossAttnDownBlockSpatioTemporal",
|
||||
"DownBlockSpatioTemporal",
|
||||
"UpBlockSpatioTemporal",
|
||||
"CrossAttnUpBlockSpatioTemporal",
|
||||
"UNetMidBlockSpatioTemporal",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
74
tests/others/test_flashpack.py
Normal file
74
tests/others/test_flashpack.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from diffusers.models.auto_model import AutoModel
|
||||
|
||||
from ..testing_utils import is_torch_available, require_flashpack, require_torch_gpu
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
class FlashPackTests(unittest.TestCase):
|
||||
model_id: str = "hf-internal-testing/tiny-flux-pipe"
|
||||
|
||||
@require_flashpack
|
||||
def test_save_load_model(self):
|
||||
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model.save_pretrained(temp_dir, use_flashpack=True)
|
||||
self.assertTrue((pathlib.Path(temp_dir) / "model.flashpack").exists())
|
||||
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True)
|
||||
|
||||
@require_flashpack
|
||||
def test_save_load_pipeline(self):
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(self.model_id)
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
pipeline.save_pretrained(temp_dir, use_flashpack=True)
|
||||
self.assertTrue((pathlib.Path(temp_dir) / "transformer" / "model.flashpack").exists())
|
||||
self.assertTrue((pathlib.Path(temp_dir) / "vae" / "model.flashpack").exists())
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(temp_dir, use_flashpack=True)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_flashpack
|
||||
def test_load_model_device_str(self):
|
||||
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model.save_pretrained(temp_dir, use_flashpack=True)
|
||||
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "cuda"})
|
||||
self.assertTrue(model.device.type == "cuda")
|
||||
|
||||
@require_torch_gpu
|
||||
@require_flashpack
|
||||
def test_load_model_device(self):
|
||||
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model.save_pretrained(temp_dir, use_flashpack=True)
|
||||
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": torch.device("cuda")})
|
||||
self.assertTrue(model.device.type == "cuda")
|
||||
|
||||
@require_flashpack
|
||||
def test_load_model_device_auto(self):
|
||||
model = AutoModel.from_pretrained(self.model_id, subfolder="transformer")
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
model.save_pretrained(temp_dir, use_flashpack=True)
|
||||
with self.assertRaises(ValueError):
|
||||
model = AutoModel.from_pretrained(temp_dir, use_flashpack=True, device_map={"": "auto"})
|
||||
@@ -34,6 +34,7 @@ from diffusers.utils.import_utils import (
|
||||
is_accelerate_available,
|
||||
is_bitsandbytes_available,
|
||||
is_compel_available,
|
||||
is_flashpack_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_kernels_available,
|
||||
@@ -737,6 +738,13 @@ def require_accelerate(test_case):
|
||||
return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
|
||||
|
||||
|
||||
def require_flashpack(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires flashpack. These tests are skipped when flashpack isn't installed.
|
||||
"""
|
||||
return pytest.mark.skipif(not is_flashpack_available(), reason="test requires flashpack")(test_case)
|
||||
|
||||
|
||||
def require_peft_version_greater(peft_version):
|
||||
"""
|
||||
Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
|
||||
|
||||
Reference in New Issue
Block a user