mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-06 10:54:59 +08:00
Compare commits
2 Commits
qwen-test-
...
wan-test-r
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8a3ef8a52 | ||
|
|
b712042da1 |
@@ -2321,8 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
|
||||
prefix = "diffusion_model."
|
||||
original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
|
||||
|
||||
num_double_layers = 8
|
||||
num_single_layers = 48
|
||||
num_double_layers = 0
|
||||
num_single_layers = 0
|
||||
for key in original_state_dict.keys():
|
||||
if key.startswith("single_blocks."):
|
||||
num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
|
||||
elif key.startswith("double_blocks."):
|
||||
num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
|
||||
|
||||
lora_keys = ("lora_A", "lora_B")
|
||||
attn_types = ("img_attn", "txt_attn")
|
||||
|
||||
|
||||
@@ -13,87 +13,49 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return QwenImageTransformer2DModel
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
def output_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
sequence_length = 7
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
@@ -108,12 +70,29 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
@@ -125,56 +104,55 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
)
|
||||
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
assert isinstance(rope_text_seq_len, int)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2 # Only 2 True values
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 5 tokens are valid (seq_len=8)
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
|
||||
|
||||
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
|
||||
)
|
||||
|
||||
# Max valid position is 7 (last token), so per_sample_len should be 8
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5 # 5 True values
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(rope_text_seq_len_none, int)
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])"""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
# Pattern: [True, False, True, False, True, False, False, False] (seq_len=8)
|
||||
# Pattern: [True, False, True, False, True, False, False]
|
||||
encoder_hidden_states_mask[:, 1] = 0
|
||||
encoder_hidden_states_mask[:, 3] = 0
|
||||
encoder_hidden_states_mask[:, 5:] = 0
|
||||
@@ -182,22 +160,21 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
assert int(per_sample_len.max().item()) == 5
|
||||
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(inferred_rope_len, int)
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
@@ -209,24 +186,18 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
# Test that deprecation warning is raised
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
# Verify a FutureWarning was raised
|
||||
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(future_warnings[0].message)
|
||||
assert "txt_seq_lens" in warning_message
|
||||
assert "deprecated" in warning_message
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
@@ -237,7 +208,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4, # Must be divisible by 2 for Ulysses context parallel
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
@@ -249,11 +220,11 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 8
|
||||
text_seq_len = 7
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
@@ -291,104 +262,24 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
assert output.sample.shape[1] == hidden_states.shape[1]
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
|
||||
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for QwenImage Transformer."""
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for QwenImage Transformer."""
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8 # Must be divisible by 2 for context parallel tests
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
@@ -409,13 +300,13 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
# Keep the all-ones mask
|
||||
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
@@ -429,8 +320,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
@@ -451,16 +342,8 @@ class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCom
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
|
||||
|
||||
|
||||
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for QwenImage Transformer."""
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
|
||||
@@ -12,57 +12,52 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
class WanTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanTransformer3DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
@@ -76,16 +71,118 @@ class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3D(WanTransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DMemory(WanTransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DTraining(WanTransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan Transformer 3D."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class WanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = WanTransformer3DModel
|
||||
class TestWanTransformer3DAttention(WanTransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan Transformer 3D."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return WanTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestWanTransformer3DCompile(WanTransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DBitsAndBytes(WanTransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DTorchAo(WanTransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanTransformer3DGGUF(WanTransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan I2V model dimensions.
|
||||
|
||||
Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanTransformer3DGGUFCompile(WanTransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-I2V-A14B-GGUF/blob/main/LowNoise/Wan2.2-I2V-A14B-LowNoise-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan I2V model dimensions.
|
||||
|
||||
Wan 2.2 I2V: in_channels=36, text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
@@ -12,76 +12,57 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanAnimateTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = WanAnimateTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
class WanAnimateTransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanAnimateTransformer3DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
clip_seq_len = 12
|
||||
clip_dim = 16
|
||||
|
||||
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
|
||||
face_height = 16 # Should be square and match `motion_encoder_size` below
|
||||
face_width = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, 2 * num_channels + 4, num_frames + 1, height, width)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
|
||||
clip_ref_features = torch.randn((batch_size, clip_seq_len, clip_dim)).to(torch_device)
|
||||
pose_latents = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
face_pixel_values = torch.randn((batch_size, 3, inference_segment_length, face_height, face_width)).to(
|
||||
torch_device
|
||||
)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_image": clip_ref_features,
|
||||
"pose_hidden_states": pose_latents,
|
||||
"face_pixel_values": face_pixel_values,
|
||||
}
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
# Output has fewer channels than input (4 vs 12)
|
||||
return (4, 21, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (12, 1, 16, 16)
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (12, 21, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 1, 16, 16)
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | float | dict]:
|
||||
# Use custom channel sizes since the default Wan Animate channel sizes will cause the motion encoder to
|
||||
# contain the vast majority of the parameters in the test model
|
||||
channel_sizes = {"4": 16, "8": 16, "16": 16}
|
||||
|
||||
init_dict = {
|
||||
return {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
@@ -105,22 +86,158 @@ class WanAnimateTransformer3DTests(ModelTesterMixin, unittest.TestCase):
|
||||
"face_encoder_num_heads": 2,
|
||||
"inject_face_latents_blocks": 2,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 4
|
||||
num_frames = 20 # To make the shapes work out; for complicated reasons we want 21 to divide num_frames + 1
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 16
|
||||
sequence_length = 12
|
||||
|
||||
clip_seq_len = 12
|
||||
clip_dim = 16
|
||||
|
||||
inference_segment_length = 77 # The inference segment length in the full Wan2.2-Animate-14B model
|
||||
face_height = 16 # Should be square and match `motion_encoder_size`
|
||||
face_width = 16
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, 2 * num_channels + 4, num_frames + 1, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(batch_size, clip_seq_len, clip_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(batch_size, 3, inference_segment_length, face_height, face_width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3D(WanAnimateTransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan Animate Transformer 3D."""
|
||||
|
||||
def test_output(self):
|
||||
# Override test_output because the transformer output is expected to have less channels
|
||||
# than the main transformer input.
|
||||
expected_output_shape = (1, 4, 21, 16, 16)
|
||||
super().test_output(expected_output_shape=expected_output_shape)
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DMemory(WanAnimateTransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DTraining(WanAnimateTransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan Animate Transformer 3D."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanAnimateTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# Override test_output because the transformer output is expected to have less channels than the main transformer
|
||||
# input.
|
||||
def test_output(self):
|
||||
expected_output_shape = (1, 4, 21, 16, 16)
|
||||
super().test_output(expected_output_shape=expected_output_shape)
|
||||
|
||||
class TestWanAnimateTransformer3DAttention(WanAnimateTransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan Animate Transformer 3D."""
|
||||
|
||||
|
||||
class WanAnimateTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = WanAnimateTransformer3DModel
|
||||
class TestWanAnimateTransformer3DCompile(WanAnimateTransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan Animate Transformer 3D."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return WanAnimateTransformer3DTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
class TestWanAnimateTransformer3DBitsAndBytes(WanAnimateTransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DTorchAo(WanAnimateTransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DGGUF(WanAnimateTransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan Animate model dimensions.
|
||||
|
||||
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanAnimateTransformer3DGGUFCompile(WanAnimateTransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan Animate Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.2-Animate-14B-GGUF/blob/main/Wan2.2-Animate-14B-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan Animate model dimensions.
|
||||
|
||||
Wan 2.2 Animate: in_channels=36 (2*16+4), text_dim=4096, image_dim=1280
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 36, 21, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states_image": randn_tensor(
|
||||
(1, 257, 1280), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"pose_hidden_states": randn_tensor(
|
||||
(1, 16, 20, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"face_pixel_values": randn_tensor(
|
||||
(1, 3, 77, 512, 512), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
198
tests/models/transformers/test_models_transformer_wan_vace.py
Normal file
198
tests/models/transformers/test_models_transformer_wan_vace.py
Normal file
@@ -0,0 +1,198 @@
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import WanVACETransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class WanVACETransformer3DTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return WanVACETransformer3DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, ...]:
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, ...]:
|
||||
return (16, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool | None]:
|
||||
return {
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 12,
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"text_dim": 32,
|
||||
"freq_dim": 256,
|
||||
"ffn_dim": 32,
|
||||
"num_layers": 4,
|
||||
"cross_attn_norm": True,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"rope_max_seq_len": 32,
|
||||
"vace_layers": [0, 2],
|
||||
"vace_in_channels": 48, # 3 * in_channels = 3 * 16 = 48
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_channels = 16
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
text_encoder_embedding_dim = 32
|
||||
sequence_length = 12
|
||||
|
||||
# VACE requires control_hidden_states with vace_in_channels (3 * in_channels)
|
||||
vace_in_channels = 48
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, text_encoder_embedding_dim),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(batch_size, vace_in_channels, num_frames, height, width),
|
||||
generator=self.generator,
|
||||
device=torch_device,
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3D(WanVACETransformer3DTesterConfig, ModelTesterMixin):
|
||||
"""Core model tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DMemory(WanVACETransformer3DTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DTraining(WanVACETransformer3DTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Wan VACE Transformer 3D."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"WanVACETransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestWanVACETransformer3DAttention(WanVACETransformer3DTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DCompile(WanVACETransformer3DTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DBitsAndBytes(WanVACETransformer3DTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DTorchAo(WanVACETransformer3DTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
|
||||
class TestWanVACETransformer3DGGUF(WanVACETransformer3DTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan VACE model dimensions.
|
||||
|
||||
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
|
||||
|
||||
class TestWanVACETransformer3DGGUFCompile(WanVACETransformer3DTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Wan VACE Transformer 3D."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/QuantStack/Wan2.1_14B_VACE-GGUF/blob/main/Wan2.1_14B_VACE-Q3_K_S.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real Wan VACE model dimensions.
|
||||
|
||||
Wan 2.1 VACE: in_channels=16, text_dim=4096, vace_in_channels=96
|
||||
"""
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(1, 16, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(1, 512, 4096), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"control_hidden_states": randn_tensor(
|
||||
(1, 96, 2, 64, 64), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device, self.torch_dtype),
|
||||
}
|
||||
Reference in New Issue
Block a user