mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-11 13:25:20 +08:00
Compare commits
2 Commits
wan-test-r
...
qwen-test-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffdfe28983 | ||
|
|
31ed009706 |
@@ -2321,14 +2321,8 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in original_state_dict.keys():
|
||||
if key.startswith("single_blocks."):
|
||||
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
|
||||
elif key.startswith("double_blocks."):
|
||||
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if not attn.is_cross_attention:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -219,10 +219,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -230,7 +227,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if not self.is_cross_attention:
|
||||
if self.cross_attention_dim_head is None:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -42,7 +42,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if not attn.is_cross_attention:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -214,10 +214,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -225,7 +222,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if not self.is_cross_attention:
|
||||
if self.cross_attention_dim_head is None:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -54,7 +54,7 @@ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, enco
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if not attn.is_cross_attention:
|
||||
if attn.cross_attention_dim_head is None:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
@@ -502,16 +502,13 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
|
||||
dim_head: int = 64,
|
||||
eps: float = 1e-6,
|
||||
cross_attention_dim_head: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.cross_attention_head_dim = cross_attention_dim_head
|
||||
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
self.use_bias = bias
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
# 1. Pre-Attention Norms for the hidden_states (video latents) and encoder_hidden_states (motion vector).
|
||||
# NOTE: this is not used in "vanilla" WanAttention
|
||||
@@ -519,10 +516,10 @@ class WanAnimateFaceBlockCrossAttention(nn.Module, AttentionModuleMixin):
|
||||
self.pre_norm_kv = nn.LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
# 2. QKV and Output Projections
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=bias)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=bias)
|
||||
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=bias)
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = torch.nn.Linear(self.inner_dim, dim, bias=True)
|
||||
|
||||
# 3. QK Norm
|
||||
# NOTE: this is applied after the reshape, so only over dim_head rather than dim_head * heads
|
||||
@@ -685,10 +682,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
@@ -696,7 +690,7 @@ class WanAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if not self.is_cross_attention:
|
||||
if self.cross_attention_dim_head is None:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
|
||||
@@ -76,7 +76,6 @@ class WanVACETransformerBlock(nn.Module):
|
||||
eps=eps,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
processor=WanAttnProcessor(),
|
||||
is_cross_attention=True,
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
@@ -179,7 +178,6 @@ class WanVACETransformer3DModel(
|
||||
_no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
_keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["WanTransformerBlock", "WanVACETransformerBlock"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
|
||||
@@ -41,7 +41,7 @@ class GGUFQuantizer(DiffusersQuantizer):
|
||||
|
||||
self.compute_dtype = quantization_config.compute_dtype
|
||||
self.pre_quantized = quantization_config.pre_quantized
|
||||
self.modules_to_not_convert = quantization_config.modules_to_not_convert or []
|
||||
self.modules_to_not_convert = quantization_config.modules_to_not_convert
|
||||
|
||||
if not isinstance(self.modules_to_not_convert, list):
|
||||
self.modules_to_not_convert = [self.modules_to_not_convert]
|
||||
|
||||
@@ -446,17 +446,16 @@ class ModelTesterMixin:
|
||||
torch_device not in ["cuda", "xpu"],
|
||||
reason="float16 and bfloat16 can only be used with an accelerator",
|
||||
)
|
||||
def test_keep_in_fp32_modules(self, tmp_path):
|
||||
def test_keep_in_fp32_modules(self):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
fp32_modules = model._keep_in_fp32_modules
|
||||
|
||||
if fp32_modules is None or len(fp32_modules) == 0:
|
||||
pytest.skip("Model does not have _keep_in_fp32_modules defined.")
|
||||
|
||||
# Save the model and reload with float16 dtype
|
||||
# _keep_in_fp32_modules is only enforced during from_pretrained loading
|
||||
model.save_pretrained(tmp_path)
|
||||
model = self.model_class.from_pretrained(tmp_path, torch_dtype=torch.float16).to(torch_device)
|
||||
# Test with float16
|
||||
model.to(torch_device)
|
||||
model.to(torch.float16)
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in fp32_modules):
|
||||
@@ -471,7 +470,7 @@ class ModelTesterMixin:
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
@torch.no_grad()
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype, atol=1e-4, rtol=0):
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
model = self.model_class(**self.get_init_dict())
|
||||
model.to(torch_device)
|
||||
fp32_modules = model._keep_in_fp32_modules or []
|
||||
@@ -491,6 +490,10 @@ class ModelTesterMixin:
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
output_loaded = model_loaded(**inputs, return_dict=False)[0]
|
||||
|
||||
self._check_dtype_inference_output(output, output_loaded, dtype)
|
||||
|
||||
def _check_dtype_inference_output(self, output, output_loaded, dtype, atol=1e-4, rtol=0):
|
||||
"""Check dtype inference output with configurable tolerance."""
|
||||
assert_tensors_close(
|
||||
output, output_loaded, atol=atol, rtol=rtol, msg=f"Loaded model output differs for {dtype}"
|
||||
)
|
||||
|
||||
@@ -176,7 +176,15 @@ class QuantizationTesterMixin:
|
||||
model_quantized = self._create_quantized_model(config_kwargs)
|
||||
model_quantized.to(torch_device)
|
||||
|
||||
# Get model dtype from first parameter
|
||||
model_dtype = next(model_quantized.parameters()).dtype
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
# Cast inputs to model dtype
|
||||
inputs = {
|
||||
k: v.to(model_dtype) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model_quantized(**inputs, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
@@ -221,8 +229,6 @@ class QuantizationTesterMixin:
|
||||
init_lora_weights=False,
|
||||
)
|
||||
model.add_adapter(lora_config)
|
||||
# Move LoRA adapter weights to device (they default to CPU)
|
||||
model.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
output = model(**inputs, return_dict=False)[0]
|
||||
@@ -1015,6 +1021,9 @@ class GGUFTesterMixin(GGUFConfigMixin, QuantizationTesterMixin):
|
||||
"""Test that dequantize() works correctly."""
|
||||
self._test_dequantize({"compute_dtype": torch.bfloat16})
|
||||
|
||||
def test_gguf_quantized_layers(self):
|
||||
self._test_quantized_layers({"compute_dtype": torch.bfloat16})
|
||||
|
||||
|
||||
@is_quantization
|
||||
@is_modelopt
|
||||
|
||||
@@ -13,49 +13,87 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return QwenImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 7
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
@@ -70,29 +108,12 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
@@ -104,55 +125,56 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
assert isinstance(rope_text_seq_len, int)
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2 # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 5 tokens are valid (seq_len=8)
|
||||
|
||||
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
|
||||
)
|
||||
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
# Max valid position is 7 (last token), so per_sample_len should be 8
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5 # 5 True values
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(rope_text_seq_len_none, int)
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])"""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
# Pattern: [True, False, True, False, True, False, False]
|
||||
# Pattern: [True, False, True, False, True, False, False, False] (seq_len=8)
|
||||
encoder_hidden_states_mask[:, 1] = 0
|
||||
encoder_hidden_states_mask[:, 3] = 0
|
||||
encoder_hidden_states_mask[:, 5:] = 0
|
||||
@@ -160,21 +182,22 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
assert int(per_sample_len.max().item()) == 5
|
||||
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(inferred_rope_len, int)
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
@@ -186,18 +209,24 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
# Test that deprecation warning is raised
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
# Verify a FutureWarning was raised
|
||||
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(future_warnings[0].message)
|
||||
assert "txt_seq_lens" in warning_message
|
||||
assert "deprecated" in warning_message
|
||||
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
@@ -208,7 +237,7 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
@@ -220,11 +249,11 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 7
|
||||
text_seq_len = 8
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
@@ -262,24 +291,104 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
assert output.sample.shape[1] == hidden_states.shape[1]
|
||||
|
||||
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for QwenImage Transformer."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for QwenImage Transformer."""
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
@@ -300,13 +409,13 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
# Keep the all-ones mask
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
@@ -320,8 +429,8 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
@@ -342,8 +451,16 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
|
||||
|
||||
|
||||
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for QwenImage Transformer."""
|
||||
|
||||
@@ -12,57 +12,57 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanTransformer3DModel
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-wan22-transformer"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 2, 16, 16)
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
@@ -76,160 +76,16 @@ class WanTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan Transformer 3D."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
|
||||
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
|
||||
pytest.skip("Tolerance requirements too high for meaningful test")
|
||||
|
||||
|
||||
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan Transformer 3D."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan Transformer 3D."""
|
||||
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
|
||||
|
||||
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.float16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
|
||||
return super()._create_quantized_model(
|
||||
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan I2V model dimensions.
|
||||
|
||||
Wan 2.2 I2V: in_channels=36, text_dim=4096
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def _create_quantized_model(self, config_kwargs=None, **extra_kwargs):
|
||||
return super()._create_quantized_model(
|
||||
config_kwargs, config="Wan-AI/Wan2.2-I2V-A14B-Diffusers", subfolder="transformer", **extra_kwargs
|
||||
)
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan I2V model dimensions.
|
||||
|
||||
Wan 2.2 I2V: in_channels=36, text_dim=4096
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -12,62 +12,76 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanAnimateTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanAnimateTransformer3DModel
|
||||
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanAnimateTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-wan-animate-transformer"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
clip_seq_len = 12
|
||||
clip_dim = 16
|
||||
|
||||
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
|
||||
face_height = 16 # Should be square and match `motion_encoder_size` below
|
||||
face_width = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
|
||||
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_image": clip_ref_features,
|
||||
"pose_hidden_states": pose_latents,
|
||||
"face_pixel_values": face_pixel_values,
|
||||
}
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
# Output has fewer channels than input (4 vs 12)
|
||||
return (4, 21, 16, 16)
|
||||
def input_shape(self):
|
||||
return (12, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (12, 21, 16, 16)
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
|
||||
# contain the vast majority of the parameters in the test model
|
||||
channel_sizes = {"4": 16, "8": 16, "16": 16}
|
||||
|
||||
return {
|
||||
init_dict = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
@@ -91,219 +105,22 @@ class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
"face_encoder_num_heads": 2,
|
||||
"inject_face_latents_blocks": 2,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
clip_seq_len = 12
|
||||
clip_dim = 16
|
||||
|
||||
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
|
||||
face_height = 16 # Should be square and match `motion_encoder_size`
|
||||
face_width = 16
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(batch_size, clip_seq_len, clip_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(batch_size, 3, inference_segment_length, face_height, face_width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan Animate Transformer 3D."""
|
||||
|
||||
def test_output(self):
|
||||
# Override test_output because the transformer output is expected to have less channels
|
||||
# than the main transformer input.
|
||||
expected_output_shape = (1, 4, 21, 16, 16)
|
||||
super().test_output(expected_output_shape=expected_output_shape)
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
# Skip: fp16/bf16 require very high atol (~1e-2) to pass, providing little signal.
|
||||
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
|
||||
pytest.skip("Tolerance requirements too high for meaningful test")
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan Animate Transformer 3D."""
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanAnimateTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan Animate Transformer 3D."""
|
||||
# Override test_output because the transformer output is expected to have less channels than the main transformer
|
||||
# input.
|
||||
def test_output(self):
|
||||
expected_output_shape = (1, 4, 21, 16, 16)
|
||||
super().test_output(expected_output_shape=expected_output_shape)
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan Animate Transformer 3D."""
|
||||
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = WanAnimateTransformer3DModel
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
# Skip: F.pad with mode="replicate" in WanAnimateFaceEncoder triggers importlib.import_module
|
||||
# internally, which dynamo doesn't support tracing through.
|
||||
pytest.skip("F.pad with replicate mode triggers unsupported import in torch.compile")
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.float16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan Animate model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan Animate model dimensions.
|
||||
|
||||
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan Animate model dimensions.
|
||||
|
||||
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import WanVACETransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanVACETransformer3DModel
|
||||
|
||||
@property
|
||||
def pretrained_model_name_or_path(self):
|
||||
return "hf-internal-testing/tiny-wan-vace-transformer"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
|
||||
return {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"text_dim": 32,
|
||||
"freq_dim": 256,
|
||||
"ffn_dim": 32,
|
||||
"num_layers": 4,
|
||||
"cross_attn_norm": True,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
"vace_layers": [0, 2],
|
||||
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 32
|
||||
sequence_length = 12
|
||||
|
||||
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
|
||||
vace_in_channels = 48
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(batch_size, vace_in_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
|
||||
def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
|
||||
# Skip: fp16/bf16 require very high atol to pass, providing little signal.
|
||||
# Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
|
||||
pytest.skip("Tolerance requirements too high for meaningful test")
|
||||
|
||||
def test_model_parallelism(self, tmp_path):
|
||||
# Skip: Device mismatch between cuda:0 and cuda:1 in VACE control flow
|
||||
pytest.skip("Model parallelism not yet supported for WanVACE")
|
||||
|
||||
|
||||
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan VACE Transformer 3D."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanVACETransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan VACE Transformer 3D."""
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
# WanVACE has two block types (WanTransformerBlock and WanVACETransformerBlock),
|
||||
# so we need recompile_limit=2 instead of the default 1.
|
||||
import torch._dynamo
|
||||
import torch._inductor.utils
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=2),
|
||||
):
|
||||
_ = model(**inputs_dict)
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.float16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the tiny Wan VACE model dimensions."""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan VACE model dimensions.
|
||||
|
||||
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan VACE model dimensions.
|
||||
|
||||
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
@@ -168,7 +168,7 @@ def assert_tensors_close(
|
||||
max_diff = abs_diff.max().item()
|
||||
|
||||
flat_idx = abs_diff.argmax().item()
|
||||
max_idx = tuple(idx.item() for idx in torch.unravel_index(torch.tensor(flat_idx), actual.shape))
|
||||
max_idx = tuple(torch.unravel_index(torch.tensor(flat_idx), actual.shape).tolist())
|
||||
|
||||
threshold = atol + rtol * expected.abs()
|
||||
mismatched = (abs_diff > threshold).sum().item()
|
||||
|
||||
Reference in New Issue
Block a user