mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[Z-Image] various small changes, Z-Image transformer tests, etc. (#12741)
* start zimage model tests.
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* up
* Revert "up"
This reverts commit bca3e27c96.
* expand upon compilation failure reason.
* Update tests/models/transformers/test_models_transformer_z_image.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
* reinitialize the padding tokens to ones to prevent NaN problems.
* updates
* up
* skipping ZImage DiT tests
* up
* up
---------
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
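
A note on the padding-token fix mentioned in the message above: `x_pad_token` and `cap_pad_token` are created with `torch.empty`, so their initial values are arbitrary and can be NaN, which breaks deterministic test comparisons. Below is a minimal sketch of the workaround applied in the test setups in this commit, assuming the tiny dummy configuration from the new transformer tests (the keyword arguments simply mirror that test's `init_dict` and are only illustrative):

    import torch
    from diffusers import ZImageTransformer2DModel

    # Tiny Z-Image transformer, mirroring the dummy config used by the new tests.
    transformer = ZImageTransformer2DModel(
        all_patch_size=(2,),
        all_f_patch_size=(1,),
        in_channels=16,
        dim=16,
        n_layers=1,
        n_refiner_layers=1,
        n_heads=1,
        n_kv_heads=2,
        qk_norm=True,
        cap_feat_dim=16,
        rope_theta=256.0,
        t_scale=1000.0,
        axes_dims=[8, 4, 4],
        axes_lens=[256, 32, 32],
    )

    # x_pad_token / cap_pad_token come from torch.empty, so overwrite them with a
    # known value before running any deterministic comparison.
    with torch.no_grad():
        transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token))
        transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token))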
@@ -27,6 +27,7 @@ from ...models.modeling_utils import ModelMixin
 from ...models.normalization import RMSNorm
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention_dispatch import dispatch_attention_fn
+from ..modeling_outputs import Transformer2DModelOutput


 ADALN_EMBED_DIM = 256
@@ -39,17 +40,9 @@ class TimestepEmbedder(nn.Module):
         if mid_size is None:
             mid_size = out_size
         self.mlp = nn.Sequential(
-            nn.Linear(
-                frequency_embedding_size,
-                mid_size,
-                bias=True,
-            ),
+            nn.Linear(frequency_embedding_size, mid_size, bias=True),
             nn.SiLU(),
-            nn.Linear(
-                mid_size,
-                out_size,
-                bias=True,
-            ),
+            nn.Linear(mid_size, out_size, bias=True),
         )

         self.frequency_embedding_size = frequency_embedding_size
@@ -211,9 +204,7 @@ class ZImageTransformerBlock(nn.Module):

         self.modulation = modulation
         if modulation:
-            self.adaLN_modulation = nn.Sequential(
-                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
-            )
+            self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))

     def forward(
         self,
@@ -230,33 +221,19 @@ class ZImageTransformerBlock(nn.Module):

             # Attention block
             attn_out = self.attention(
-                self.attention_norm1(x) * scale_msa,
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
+                self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
             )
             x = x + gate_msa * self.attention_norm2(attn_out)

             # FFN block
-            x = x + gate_mlp * self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x) * scale_mlp,
-                )
-            )
+            x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
         else:
             # Attention block
-            attn_out = self.attention(
-                self.attention_norm1(x),
-                attention_mask=attn_mask,
-                freqs_cis=freqs_cis,
-            )
+            attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
             x = x + self.attention_norm2(attn_out)

             # FFN block
-            x = x + self.ffn_norm2(
-                self.feed_forward(
-                    self.ffn_norm1(x),
-                )
-            )
+            x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))

         return x

@@ -404,10 +381,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
             ]
         )
         self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
-        self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps),
-            nn.Linear(cap_feat_dim, dim, bias=True),
-        )
+        self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))

         self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
         self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
@@ -494,11 +468,8 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
            )

            # padded feature
-            cap_padded_feat = torch.cat(
-                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
-                dim=0,
-            )
-            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)
+            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
+            all_cap_feats_out.append(cap_padded_feat)

            ### Process Image
            C, F, H, W = image.size()
@@ -564,6 +535,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         cap_feats: List[torch.Tensor],
         patch_size=2,
         f_patch_size=1,
+        return_dict: bool = True,
     ):
         assert patch_size in self.all_patch_size
         assert f_patch_size in self.all_f_patch_size
@@ -672,4 +644,7 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         unified = list(unified.unbind(dim=0))
         x = self.unpatchify(unified, x_size, patch_size, f_patch_size)

-        return x, {}
+        if not return_dict:
+            return (x,)
+
+        return Transformer2DModelOutput(sample=x)
@@ -525,9 +525,7 @@ class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMix
                latent_model_input_list = list(latent_model_input.unbind(dim=0))

                model_out_list = self.transformer(
-                    latent_model_input_list,
-                    timestep_model_input,
-                    prompt_embeds_model_input,
+                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
                )[0]

                if apply_cfg:
@@ -15,17 +15,13 @@
 import sys
 import unittest

 import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

-from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend
+from ..testing_utils import floats_tensor, is_peft_available, require_peft_backend, skip_mps, torch_device


 if is_peft_available():
@@ -34,13 +30,9 @@ if is_peft_available():

 sys.path.append(".")

-from .utils import PeftLoraLoaderMixinTests  # noqa: E402
+from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa: E402


-@unittest.skip(
-    "ZImage LoRA tests are skipped due to non-deterministic behavior from complex64 RoPE operations "
-    "and torch.empty padding tokens. LoRA functionality works correctly with real models."
-)
 @require_peft_backend
 class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = ZImagePipeline
@@ -127,6 +119,12 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         tokenizer = Qwen2Tokenizer.from_pretrained(self.tokenizer_id)

         transformer = self.transformer_cls(**self.transformer_kwargs)
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN data values in our testing environment. Fixating them
+        # helps prevent that issue.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))
         vae = self.vae_cls(**self.vae_kwargs)

         if scheduler_cls is None:
@@ -161,3 +159,127 @@ class ZImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         }

         return pipeline_components, text_lora_config, denoiser_lora_config
+
+    def test_correct_lora_configs_with_different_ranks(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+
+        lora_output_same_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        denoiser = pipe.unet if self.unet_kwargs is not None else pipe.transformer
+        for name, _ in denoiser.named_modules():
+            if "to_k" in name and "attention" in name and "lora" not in name:
+                module_name_to_rank_update = name.replace(".base_layer.", ".")
+                break
+
+        # change the rank_pattern
+        updated_rank = denoiser_lora_config.r * 2
+        denoiser_lora_config.rank_pattern = {module_name_to_rank_update: updated_rank}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        updated_rank_pattern = pipe.transformer.peft_config["adapter-1"].rank_pattern
+
+        self.assertTrue(updated_rank_pattern == {module_name_to_rank_update: updated_rank})
+
+        lora_output_diff_rank = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_rank, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+        pipe.transformer.delete_adapters("adapter-1")
+
+        # similarly change the alpha_pattern
+        updated_alpha = denoiser_lora_config.lora_alpha * 2
+        denoiser_lora_config.alpha_pattern = {module_name_to_rank_update: updated_alpha}
+
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(
+            pipe.transformer.peft_config["adapter-1"].alpha_pattern == {module_name_to_rank_update: updated_alpha}
+        )
+
+        lora_output_diff_alpha = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(not np.allclose(original_output, lora_output_diff_alpha, atol=1e-3, rtol=1e-3))
+        self.assertTrue(not np.allclose(lora_output_diff_alpha, lora_output_same_rank, atol=1e-3, rtol=1e-3))
+
+    @skip_mps
+    def test_lora_fuse_nan(self):
+        components, _, denoiser_lora_config = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        # corrupt one LoRA weight with `inf` values
+        with torch.no_grad():
+            possible_tower_names = ["noise_refiner"]
+            filtered_tower_names = [
+                tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
+            ]
+            for tower_name in filtered_tower_names:
+                transformer_tower = getattr(pipe.transformer, tower_name)
+                transformer_tower[0].attention.to_q.lora_A["adapter-1"].weight += float("inf")
+
+        # with `safe_fusing=True` we should see an Error
+        with self.assertRaises(ValueError):
+            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
+
+        # without we should not see an error, but every image will be black
+        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
+        out = pipe(**inputs)[0]
+
+        self.assertTrue(np.isnan(out).all())
+
+    def test_lora_scale_kwargs_match_fusion(self):
+        super().test_lora_scale_kwargs_match_fusion(5e-2, 5e-2)
+
+    @unittest.skip("Needs to be debugged.")
+    def test_set_adapters_match_attention_kwargs(self):
+        super().test_set_adapters_match_attention_kwargs()
+
+    @unittest.skip("Needs to be debugged.")
+    def test_simple_inference_with_text_denoiser_lora_and_scale(self):
+        super().test_simple_inference_with_text_denoiser_lora_and_scale()
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
+
+    @unittest.skip("Not supported in ZImage.")
+    def test_modify_padding_mode(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_partial_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_and_scale(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_fused(self):
+        pass
+
+    @unittest.skip("Text encoder LoRA is not supported in ZImage.")
+    def test_simple_inference_with_text_lora_save_load(self):
+        pass
@@ -47,6 +47,7 @@ from diffusers.models.attention_processor import (
     XFormersAttnProcessor,
 )
 from diffusers.models.auto_model import AutoModel
+from diffusers.models.modeling_outputs import BaseOutput
 from diffusers.training_utils import EMAModel
 from diffusers.utils import (
     SAFE_WEIGHTS_INDEX_NAME,
@@ -108,6 +109,11 @@ def check_if_lora_correctly_set(model) -> bool:
     return False


+def normalize_output(out):
+    out0 = out[0] if isinstance(out, (BaseOutput, tuple)) else out
+    return torch.stack(out0) if isinstance(out0, list) else out0
+
+
 # Will be run via run_test_in_subprocess
 def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
     error = None
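
The `normalize_output` helper introduced in this hunk is what lets the shared model tests compare outputs of models such as `ZImageTransformer2DModel`, whose forward pass returns a list of per-sample tensors rather than a single batched tensor. A small standalone illustration of the same idea (the `BaseOutput` branch is left out here for brevity):

    import torch

    def normalize_output(out):
        # Unwrap a tuple-style output, then stack a list of per-sample tensors into one tensor.
        out0 = out[0] if isinstance(out, tuple) else out
        return torch.stack(out0) if isinstance(out0, list) else out0

    list_output = ([torch.randn(4, 32, 32), torch.randn(4, 32, 32)],)  # tuple wrapping a list
    print(normalize_output(list_output).shape)            # torch.Size([2, 4, 32, 32])
    print(normalize_output(torch.randn(2, 4, 8)).shape)   # plain tensors pass through unchanged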
@@ -536,6 +542,9 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]

+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
+
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

@@ -780,6 +789,9 @@ class ModelTesterMixin:
         if isinstance(new_image, dict):
             new_image = new_image.to_tuple()[0]

+        image = normalize_output(image)
+        new_image = normalize_output(new_image)
+
         max_diff = (image - new_image).abs().max().item()
         self.assertLessEqual(max_diff, expected_max_diff, "Models give different forward passes")

@@ -842,6 +854,9 @@ class ModelTesterMixin:
         if isinstance(second, dict):
             second = second.to_tuple()[0]

+        first = normalize_output(first)
+        second = normalize_output(second)
+
         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
         out_1 = out_1[~np.isnan(out_1)]
@@ -860,11 +875,15 @@ class ModelTesterMixin:

         if isinstance(output, dict):
             output = output.to_tuple()[0]
+        if isinstance(output, list):
+            output = torch.stack(output)

         self.assertIsNotNone(output)

         # input & output have to have the same shape
         input_tensor = inputs_dict[self.main_input_name]
+        if isinstance(input_tensor, list):
+            input_tensor = torch.stack(input_tensor)

         if expected_output_shape is None:
             expected_shape = input_tensor.shape
@@ -898,11 +917,15 @@ class ModelTesterMixin:

         if isinstance(output_1, dict):
             output_1 = output_1.to_tuple()[0]
+        if isinstance(output_1, list):
+            output_1 = torch.stack(output_1)

         output_2 = new_model(**inputs_dict)

         if isinstance(output_2, dict):
             output_2 = output_2.to_tuple()[0]
+        if isinstance(output_2, list):
+            output_2 = torch.stack(output_2)

         self.assertEqual(output_1.shape, output_2.shape)

@@ -1138,6 +1161,8 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(output_no_lora, list):
+            output_no_lora = torch.stack(output_no_lora)

         denoiser_lora_config = LoraConfig(
             r=rank,
@@ -1151,6 +1176,8 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(outputs_with_lora, list):
+            outputs_with_lora = torch.stack(outputs_with_lora)

         self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4))

@@ -1175,6 +1202,8 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+        if isinstance(outputs_with_lora_2, list):
+            outputs_with_lora_2 = torch.stack(outputs_with_lora_2)

         self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
         self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4))
@@ -1296,31 +1325,35 @@ class ModelTesterMixin:
     def test_cpu_offload(self):
         if self.model_class._no_split_modules is None:
             pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")

         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()

         model = model.to(torch_device)

         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]

         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)

             for max_size in max_gpu_sizes:
                 max_memory = {0: max_size, "cpu": model_size * 2}
                 new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)

                 # Making sure part of the model will actually end up offloaded
                 self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})

                 self.check_device_map_is_respected(new_model, new_model.hf_device_map)

                 torch.manual_seed(0)
                 new_output = new_model(**inputs_dict)
+                new_normalized_output = normalize_output(new_output)

-                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+                self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
@@ -1333,6 +1366,7 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         max_size = int(self.model_split_percents[0] * model_size)
@@ -1352,8 +1386,8 @@ class ModelTesterMixin:
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
-
-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            new_normalized_output = normalize_output(new_output)
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
@@ -1366,6 +1400,7 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1380,8 +1415,9 @@ class ModelTesterMixin:
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
             torch.manual_seed(0)
             new_output = new_model(**inputs_dict)
+            new_normalized_output = normalize_output(new_output)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_multi_accelerator
     def test_model_parallelism(self):
@@ -1422,6 +1458,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
@@ -1443,8 +1480,9 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
+            new_normalized_output = normalize_output(new_output)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):
@@ -1454,6 +1492,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
@@ -1481,8 +1520,9 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
+            new_normalized_output = normalize_output(new_output)

-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     @require_torch_accelerator
     def test_sharded_checkpoints_with_parallel_loading(self):
@@ -1492,6 +1532,7 @@ class ModelTesterMixin:
         model = model.to(torch_device)

         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
@@ -1515,7 +1556,9 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            new_normalized_output = normalize_output(new_output)
+
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))
         # set to no.
         os.environ["HF_ENABLE_PARALLEL_LOADING"] = "no"

@@ -1529,6 +1572,7 @@ class ModelTesterMixin:

         torch.manual_seed(0)
         base_output = model(**inputs_dict)
+        base_normalized_output = normalize_output(base_output)

         model_size = compute_module_persistent_sizes(model)[""]
         max_shard_size = int((model_size * 0.75) / (2**10))  # Convert to KB as these test models are small.
@@ -1549,7 +1593,9 @@ class ModelTesterMixin:
             if "generator" in inputs_dict:
                 _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             new_output = new_model(**inputs_dict)
-            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+            new_normalized_output = normalize_output(new_output)
+
+            self.assertTrue(torch.allclose(base_normalized_output, new_normalized_output, atol=1e-5))

     # This test is okay without a GPU because we're not running any execution. We're just serializing
     # and check if the resultant files are following an expected format.
@@ -1629,7 +1675,9 @@ class ModelTesterMixin:
         model = self.model_class(**config)
         model.eval()
         model.to(torch_device)
-        base_slice = model(**inputs_dict)[0].detach().flatten().cpu().numpy()
+        base_slice = model(**inputs_dict)[0]
+        base_slice = normalize_output(base_slice)
+        base_slice = base_slice.detach().flatten().cpu().numpy()

         def check_linear_dtype(module, storage_dtype, compute_dtype):
             patterns_to_check = DEFAULT_SKIP_MODULES_PATTERN
@@ -1655,7 +1703,9 @@ class ModelTesterMixin:
         model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)

         check_linear_dtype(model, storage_dtype, compute_dtype)
-        output = model(**inputs_dict)[0].float().flatten().detach().cpu().numpy()
+        output = model(**inputs_dict)[0]
+        output = normalize_output(output)
+        output = output.float().flatten().detach().cpu().numpy()

         # The precision test is not very important for fast tests. In most cases, the outputs will not be the same.
         # We just want to make sure that the layerwise casting is working as expected.
@@ -1716,6 +1766,12 @@ class ModelTesterMixin:
     @parameterized.expand([False, True])
     @require_torch_accelerator
     def test_group_offloading(self, record_stream):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading" in cls.__dict__ and cls is not ModelTesterMixin:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                pytest.skip("Model does not support group offloading.")
+
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

@@ -1738,21 +1794,25 @@ class ModelTesterMixin:

         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1)
         output_with_group_offloading1 = run_forward(model)
+        output_with_group_offloading1 = normalize_output(output_with_group_offloading1)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=1, non_blocking=True)
         output_with_group_offloading2 = run_forward(model)
+        output_with_group_offloading2 = normalize_output(output_with_group_offloading2)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
         model.enable_group_offload(torch_device, offload_type="leaf_level")
         output_with_group_offloading3 = run_forward(model)
+        output_with_group_offloading3 = normalize_output(output_with_group_offloading3)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
@@ -1760,6 +1820,7 @@ class ModelTesterMixin:
             torch_device, offload_type="leaf_level", use_stream=True, record_stream=record_stream
         )
         output_with_group_offloading4 = run_forward(model)
+        output_with_group_offloading4 = normalize_output(output_with_group_offloading4)

         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
@@ -1799,6 +1860,12 @@ class ModelTesterMixin:
     @torch.no_grad()
     @torch.inference_mode()
     def test_group_offloading_with_disk(self, offload_type, record_stream, atol=1e-5):
+        for cls in inspect.getmro(self.__class__):
+            if "test_group_offloading_with_disk" in cls.__dict__ and cls is not ModelTesterMixin:
+                # Skip this test if it is overwritten by child class. We need to do this because parameterized
+                # materializes the test methods on invocation which cannot be overridden.
+                pytest.skip("Model does not support group offloading with disk yet.")
+
         if not self.model_class._supports_group_offloading:
             pytest.skip("Model does not support group offloading.")

@@ -1821,6 +1888,7 @@ class ModelTesterMixin:
         model.eval()
         model.to(torch_device)
         output_without_group_offloading = _run_forward(model, inputs_dict)
+        output_without_group_offloading = normalize_output(output_without_group_offloading)

         torch.manual_seed(0)
         model = self.model_class(**init_dict)
@@ -1856,6 +1924,7 @@ class ModelTesterMixin:
                 raise ValueError(f"Following files are missing: {', '.join(missing_files)}")

         output_with_group_offloading = _run_forward(model, inputs_dict)
+        output_with_group_offloading = normalize_output(output_with_group_offloading)
         self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=atol))

     def test_auto_model(self, expected_max_diff=5e-5):
@@ -1889,10 +1958,17 @@ class ModelTesterMixin:
         output_original = model(**inputs_dict)
         output_auto = auto_model(**inputs_dict)

-        if isinstance(output_original, dict):
-            output_original = output_original.to_tuple()[0]
-        if isinstance(output_auto, dict):
-            output_auto = output_auto.to_tuple()[0]
+        if isinstance(output_original, dict):
+            output_original = output_original.to_tuple()[0]
+        if isinstance(output_auto, dict):
+            output_auto = output_auto.to_tuple()[0]
+
+        if isinstance(output_original, list):
+            output_original = torch.stack(output_original)
+        if isinstance(output_auto, list):
+            output_auto = torch.stack(output_auto)
+
+        output_original, output_auto = output_original.float(), output_auto.float()

         max_diff = (output_original - output_auto).abs().max().item()
         self.assertLessEqual(
@@ -2083,6 +2159,8 @@ class TorchCompileTesterMixin:
         recompile_limit = 1
         if self.model_class.__name__ == "UNet2DConditionModel":
             recompile_limit = 2
+        elif self.model_class.__name__ == "ZImageTransformer2DModel":
+            recompile_limit = 3

         with (
             torch._inductor.utils.fresh_inductor_cache(),
@@ -2184,7 +2262,6 @@ class LoraHotSwappingForModelTesterMixin:
         backend_empty_cache(torch_device)

     def get_lora_config(self, lora_rank, lora_alpha, target_modules):
-        # from diffusers test_models_unet_2d_condition.py
         from peft import LoraConfig

         lora_config = LoraConfig(
tests/models/transformers/test_models_transformer_z_image.py (new file, +171 lines)
@@ -0,0 +1,171 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gc
+import os
+import unittest
+
+import torch
+
+from diffusers import ZImageTransformer2DModel
+
+from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+
+
+# Z-Image requires torch.use_deterministic_algorithms(False) due to complex64 RoPE operations
+# Cannot use enable_full_determinism() which sets it to True
+os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
+torch.use_deterministic_algorithms(False)
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+if hasattr(torch.backends, "cuda"):
+    torch.backends.cuda.matmul.allow_tf32 = False
+
+
+@unittest.skipIf(
+    IS_GITHUB_ACTIONS,
+    reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
+)
+class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
+    model_class = ZImageTransformer2DModel
+    main_input_name = "x"
+    # We override the items here because the transformer under consideration is small.
+    model_split_percents = [0.9, 0.9, 0.9]
+
+    def prepare_dummy_input(self, height=16, width=16):
+        batch_size = 1
+        num_channels = 16
+        embedding_dim = 16
+        sequence_length = 16
+
+        hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
+        encoder_hidden_states = [
+            torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
+        ]
+        timestep = torch.tensor([0.0]).to(torch_device)
+
+        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}
+
+    @property
+    def dummy_input(self):
+        return self.prepare_dummy_input()
+
+    @property
+    def input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "all_patch_size": (2,),
+            "all_f_patch_size": (1,),
+            "in_channels": 16,
+            "dim": 16,
+            "n_layers": 1,
+            "n_refiner_layers": 1,
+            "n_heads": 1,
+            "n_kv_heads": 2,
+            "qk_norm": True,
+            "cap_feat_dim": 16,
+            "rope_theta": 256.0,
+            "t_scale": 1000.0,
+            "axes_dims": [8, 4, 4],
+            "axes_lens": [256, 32, 32],
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def setUp(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"ZImageTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_training(self):
+        super().test_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_ema_training(self):
+        super().test_ema_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_effective_gradient_checkpointing(self):
+        super().test_effective_gradient_checkpointing()
+
+    @unittest.skip(
+        "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
+    )
+    def test_layerwise_casting_training(self):
+        super().test_layerwise_casting_training()
+
+    @unittest.skip("Test is not supported for handling main inputs that are lists.")
+    def test_outputs_equivalence(self):
+        super().test_outputs_equivalence()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading(self):
+        super().test_group_offloading()
+
+    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
+    def test_group_offloading_with_disk(self):
+        super().test_group_offloading_with_disk()
+
+
+class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
+    model_class = ZImageTransformer2DModel
+    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
+
+    def prepare_init_args_and_inputs_for_common(self):
+        return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()
+
+    def prepare_dummy_input(self, height, width):
+        return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)
+
+    @unittest.skip(
+        "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
+    )
+    def test_torch_compile_recompilation_and_graph_break(self):
+        super().test_torch_compile_recompilation_and_graph_break()
+
+    @unittest.skip("Fullgraph AoT is broken")
+    def test_compile_works_with_aot(self):
+        super().test_compile_works_with_aot()
+
+    @unittest.skip("Fullgraph is broken")
+    def test_compile_on_different_shapes(self):
+        super().test_compile_on_different_shapes()
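
A minimal usage sketch, not taken from the diff, of how the dummy inputs defined in this new test file exercise the `return_dict` handling added to `ZImageTransformer2DModel.forward`; the constructor arguments mirror the test's `init_dict`, and the output types follow the transformer hunk near the top of this commit:

    import torch
    from diffusers import ZImageTransformer2DModel
    from diffusers.models.modeling_outputs import Transformer2DModelOutput

    model = ZImageTransformer2DModel(
        all_patch_size=(2,), all_f_patch_size=(1,), in_channels=16, dim=16,
        n_layers=1, n_refiner_layers=1, n_heads=1, n_kv_heads=2, qk_norm=True,
        cap_feat_dim=16, rope_theta=256.0, t_scale=1000.0,
        axes_dims=[8, 4, 4], axes_lens=[256, 32, 32],
    )

    # Inputs are lists of per-sample tensors, as in prepare_dummy_input above.
    x = [torch.randn(16, 1, 16, 16)]    # (C, F, H, W) latent for one sample
    cap_feats = [torch.randn(16, 16)]   # (sequence_length, cap_feat_dim) caption features
    t = torch.tensor([0.0])

    out_tuple = model(x=x, t=t, cap_feats=cap_feats, return_dict=False)  # plain tuple, sample at index 0
    out = model(x=x, t=t, cap_feats=cap_feats)                           # default: Transformer2DModelOutput
    assert isinstance(out, Transformer2DModelOutput)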
@@ -20,12 +20,7 @@ import numpy as np
 import torch
 from transformers import Qwen2Tokenizer, Qwen3Config, Qwen3Model

-from diffusers import (
-    AutoencoderKL,
-    FlowMatchEulerDiscreteScheduler,
-    ZImagePipeline,
-    ZImageTransformer2DModel,
-)
+from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, ZImagePipeline, ZImageTransformer2DModel

 from ...testing_utils import torch_device
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -106,6 +101,12 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             axes_dims=[8, 4, 4],
             axes_lens=[256, 32, 32],
         )
+        # `x_pad_token` and `cap_pad_token` are initialized with `torch.empty`.
+        # This can cause NaN data values in our testing environment. Fixating them
+        # helps prevent that issue.
+        with torch.no_grad():
+            transformer.x_pad_token.copy_(torch.ones_like(transformer.x_pad_token.data))
+            transformer.cap_pad_token.copy_(torch.ones_like(transformer.cap_pad_token.data))

         torch.manual_seed(0)
         vae = AutoencoderKL(
@@ -183,7 +184,7 @@ class ZImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertEqual(generated_image.shape, (3, 32, 32))

         # fmt: off
-        expected_slice = torch.tensor([0.4521, 0.4512, 0.4693, 0.5115, 0.5250, 0.5271, 0.4776, 0.4688, 0.2765, 0.2164, 0.5656, 0.6909, 0.3831, 0.5431, 0.5493, 0.4732])
+        expected_slice = torch.tensor([0.4622, 0.4532, 0.4714, 0.5087, 0.5371, 0.5405, 0.4492, 0.4479, 0.2984, 0.2783, 0.5409, 0.6577, 0.3952, 0.5524, 0.5262, 0.453])
         # fmt: on

         generated_slice = generated_image.flatten()