@@ -13,49 +13,87 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import unittest

 import torch

 from diffusers import QwenImageTransformer2DModel
 from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
+from diffusers.utils.torch_utils import randn_tensor

 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
+from ..testing_utils import (
+    AttentionTesterMixin,
+    BaseModelTesterConfig,
+    BitsAndBytesTesterMixin,
+    ContextParallelTesterMixin,
+    LoraHotSwappingForModelTesterMixin,
+    LoraTesterMixin,
+    MemoryTesterMixin,
+    ModelTesterMixin,
+    TorchAoTesterMixin,
+    TorchCompileTesterMixin,
+    TrainingTesterMixin,
+)


 enable_full_determinism()


-class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
-    model_class = QwenImageTransformer2DModel
-    main_input_name = "hidden_states"
-    # We override the items here because the transformer under consideration is small.
-    model_split_percents = [0.7, 0.6, 0.6]
-
-    # Skip setting testing with default: AttnProcessor
-    uses_custom_attn_processor = True
+class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return QwenImageTransformer2DModel

-    @property
-    def dummy_input(self):
-        return self.prepare_dummy_input()
-
     @property
-    def input_shape(self):
+    def output_shape(self) -> tuple[int, int]:
         return (16, 16)

     @property
-    def output_shape(self):
+    def input_shape(self) -> tuple[int, int]:
         return (16, 16)

-    def prepare_dummy_input(self, height=4, width=4):
+    @property
+    def model_split_percents(self) -> list:
+        # We override the items here because the transformer under consideration is small.
+        return [0.7, 0.6, 0.6]
+
+    @property
+    def main_input_name(self) -> str:
+        return "hidden_states"
+
+    @property
+    def uses_custom_attn_processor(self) -> bool:
+        # Skip setting testing with default: AttnProcessor
+        return True
+
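+    # A fixed CPU seed keeps the randn_tensor(...) dummy inputs below reproducible across runs and test classes.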
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+
+    def get_init_dict(self) -> dict[str, int | list[int]]:
+        return {
+            "patch_size": 2,
+            "in_channels": 16,
+            "out_channels": 4,
+            "num_layers": 2,
+            "attention_head_dim": 16,
+            "num_attention_heads": 4,  # Must be divisible by 2 for Ulysses context parallel
+            "joint_attention_dim": 16,
+            "guidance_embeds": False,
+            "axes_dims_rope": (8, 4, 4),
+        }
+
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
         batch_size = 1
         num_latent_channels = embedding_dim = 16
-        sequence_length = 7
+        sequence_length = 8  # Must be divisible by 2 for context parallel tests
         vae_scale_factor = 4

-        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        hidden_states = randn_tensor(
+            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states = randn_tensor(
+            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
+        )
         encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
         timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
         orig_height = height * 2 * vae_scale_factor
@@ -70,29 +108,12 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
             "img_shapes": img_shapes,
         }

-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
-            "patch_size": 2,
-            "in_channels": 16,
-            "out_channels": 4,
-            "num_layers": 2,
-            "attention_head_dim": 16,
-            "num_attention_heads": 3,
-            "joint_attention_dim": 16,
-            "guidance_embeds": False,
-            "axes_dims_rope": (8, 4, 4),
-        }
-
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"QwenImageTransformer2DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

+
+class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
     def test_infers_text_seq_len_from_mask(self):
         """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
-        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        init_dict = self.get_init_dict()
+        inputs = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)

         # Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
@@ -104,55 +125,56 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
         )

         # Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
-        self.assertIsInstance(rope_text_seq_len, int)
+        assert isinstance(rope_text_seq_len, int)

         # Verify per_sample_len is computed correctly (max valid position + 1 = 2)
-        self.assertIsInstance(per_sample_len, torch.Tensor)
-        self.assertEqual(int(per_sample_len.max().item()), 2)
+        assert isinstance(per_sample_len, torch.Tensor)
+        assert int(per_sample_len.max().item()) == 2

         # Verify mask is normalized to bool dtype
-        self.assertTrue(normalized_mask.dtype == torch.bool)
-        self.assertEqual(normalized_mask.sum().item(), 2)  # Only 2 True values
+        assert normalized_mask.dtype == torch.bool
+        assert normalized_mask.sum().item() == 2  # Only 2 True values

         # Verify rope_text_seq_len is at least the sequence length
-        self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
+        assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]

         # Test 2: Verify model runs successfully with inferred values
         inputs["encoder_hidden_states_mask"] = normalized_mask
         with torch.no_grad():
             output = model(**inputs)
-        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

         # Test 3: Different mask pattern (padding at beginning)
         encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
         encoder_hidden_states_mask2[:, :3] = 0  # First 3 tokens are padding
-        encoder_hidden_states_mask2[:, 3:] = 1  # Last 4 tokens are valid
+        encoder_hidden_states_mask2[:, 3:] = 1  # Last 5 tokens are valid (seq_len=8)

         rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], encoder_hidden_states_mask2
         )

-        # Max valid position is 6 (last token), so per_sample_len should be 7
-        self.assertEqual(int(per_sample_len2.max().item()), 7)
-        self.assertEqual(normalized_mask2.sum().item(), 4)  # 4 True values
+        # Max valid position is 7 (last token), so per_sample_len should be 8
+        assert int(per_sample_len2.max().item()) == 8
+        assert normalized_mask2.sum().item() == 5  # 5 True values

         # Test 4: No mask provided (None case)
         rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], None
         )
-        self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
-        self.assertIsInstance(rope_text_seq_len_none, int)
-        self.assertIsNone(per_sample_len_none)
-        self.assertIsNone(normalized_mask_none)
+        assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
+        assert isinstance(rope_text_seq_len_none, int)
+        assert per_sample_len_none is None
+        assert normalized_mask_none is None

     def test_non_contiguous_attention_mask(self):
-        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
-        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])"""
+        init_dict = self.get_init_dict()
+        inputs = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)

         # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
         encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
-        # Pattern: [True, False, True, False, True, False, False]
+        # Pattern: [True, False, True, False, True, False, False, False] (seq_len=8)
         encoder_hidden_states_mask[:, 1] = 0
         encoder_hidden_states_mask[:, 3] = 0
         encoder_hidden_states_mask[:, 5:] = 0
@@ -160,21 +182,22 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
         inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
             inputs["encoder_hidden_states"], encoder_hidden_states_mask
         )
-        self.assertEqual(int(per_sample_len.max().item()), 5)
-        self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
-        self.assertIsInstance(inferred_rope_len, int)
-        self.assertTrue(normalized_mask.dtype == torch.bool)
+        assert int(per_sample_len.max().item()) == 5
+        assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
+        assert isinstance(inferred_rope_len, int)
+        assert normalized_mask.dtype == torch.bool

         inputs["encoder_hidden_states_mask"] = normalized_mask

         with torch.no_grad():
             output = model(**inputs)

-        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

     def test_txt_seq_lens_deprecation(self):
         """Test that passing txt_seq_lens raises a deprecation warning."""
-        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        init_dict = self.get_init_dict()
+        inputs = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)

         # Prepare inputs with txt_seq_lens (deprecated parameter)
@@ -186,18 +209,24 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
         inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

         # Test that deprecation warning is raised
-        with self.assertWarns(FutureWarning) as warning_context:
+        import warnings
+
+        with warnings.catch_warnings(record=True) as w:
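+            # "always" ensures the FutureWarning is recorded even if it was already emitted earlier in the process.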
+            warnings.simplefilter("always")
             with torch.no_grad():
                 output = model(**inputs_with_deprecated)

-        # Verify the warning message mentions the deprecation
-        warning_message = str(warning_context.warning)
-        self.assertIn("txt_seq_lens", warning_message)
-        self.assertIn("deprecated", warning_message)
-        self.assertIn("encoder_hidden_states_mask", warning_message)
+        # Verify a FutureWarning was raised
+        future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
+        assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
+
+        # Verify the warning message mentions the deprecation
+        warning_message = str(future_warnings[0].message)
+        assert "txt_seq_lens" in warning_message
+        assert "deprecated" in warning_message

         # Verify the model still works correctly despite the deprecation
-        self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

     def test_layered_model_with_mask(self):
         """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
@@ -208,7 +237,7 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
             "out_channels": 4,
             "num_layers": 2,
             "attention_head_dim": 16,
-            "num_attention_heads": 3,
+            "num_attention_heads": 4,  # Must be divisible by 2 for Ulysses context parallel
             "joint_attention_dim": 16,
             "axes_dims_rope": (8, 4, 4),  # Must match attention_head_dim (8+4+4=16)
             "use_layer3d_rope": True,  # Enable layered RoPE
@@ -220,11 +249,11 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
         # Verify the model uses QwenEmbedLayer3DRope
         from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

-        self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
+        assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)

         # Test single generation with layered structure
         batch_size = 1
-        text_seq_len = 7
+        text_seq_len = 8
         img_h, img_w = 4, 4
         layers = 4
@@ -262,24 +291,104 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
             additional_t_cond=addition_t_cond,
         )

-        self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
+        assert output.sample.shape[1] == hidden_states.shape[1]


-class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
-    model_class = QwenImageTransformer2DModel
+class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for QwenImage Transformer."""

-    def prepare_init_args_and_inputs_for_common(self):
-        return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()

-    def prepare_dummy_input(self, height, width):
-        return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
+class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
+    """Training tests for QwenImage Transformer."""

-    def test_torch_compile_recompilation_and_graph_break(self):
-        super().test_torch_compile_recompilation_and_graph_break()
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"QwenImageTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
+
+class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
+    """Attention processor tests for QwenImage Transformer."""
+
+
+class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
+    """Context Parallel inference tests for QwenImage Transformer."""
+
+
+class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
+    """LoRA adapter tests for QwenImage Transformer."""
+
+
+class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
+    """LoRA hot-swapping tests for QwenImage Transformer."""

     @property
     def different_shapes_for_compilation(self):
         return [(4, 4), (4, 8), (8, 8)]

+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
+        """Override to support dynamic height/width for LoRA hotswap tests."""
+        batch_size = 1
+        num_latent_channels = embedding_dim = 16
+        sequence_length = 8
+        vae_scale_factor = 4
+
+        hidden_states = randn_tensor(
+            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states = randn_tensor(
+            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
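+        # Each entry reduces to (1, height, width), matching the height * width image tokens in hidden_states.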
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+        }
+
+
+class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
+    @property
+    def different_shapes_for_compilation(self):
+        return [(4, 4), (4, 8), (8, 8)]
+
+    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
+        """Override to support dynamic height/width for compilation tests."""
+        batch_size = 1
+        num_latent_channels = embedding_dim = 16
+        sequence_length = 8  # Must be divisible by 2 for context parallel tests
+        vae_scale_factor = 4
+
+        hidden_states = randn_tensor(
+            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states = randn_tensor(
+            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
+        )
+        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+        orig_height = height * 2 * vae_scale_factor
+        orig_width = width * 2 * vae_scale_factor
+        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "encoder_hidden_states_mask": encoder_hidden_states_mask,
+            "timestep": timestep,
+            "img_shapes": img_shapes,
+        }
+
     def test_torch_compile_with_and_without_mask(self):
         """Test that torch.compile works with both None mask and padding mask."""
-        init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
+        init_dict = self.get_init_dict()
+        inputs = self.get_dummy_inputs()
         model = self.model_class(**init_dict).to(torch_device)
         model.eval()
         model.compile(mode="default", fullgraph=True)
@@ -300,13 +409,13 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
         ):
             output_no_mask_2 = model(**inputs_no_mask)

-        self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
-        self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
+        assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]

         # Test 2: Run with all-ones mask (should behave like None)
         inputs_all_ones = inputs.copy()
         # Keep the all-ones mask
-        self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
+        assert inputs_all_ones["encoder_hidden_states_mask"].all().item()

         # First run to allow compilation
         with torch.no_grad():
@@ -320,8 +429,8 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
         ):
             output_all_ones_2 = model(**inputs_all_ones)

-        self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
-        self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
+        assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]

         # Test 3: Run with actual padding mask (has zeros)
         inputs_with_padding = inputs.copy()
@@ -342,8 +451,16 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
         ):
             output_with_padding_2 = model(**inputs_with_padding)

-        self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
-        self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
+        assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
+        assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]

         # Verify that outputs are different (mask should affect results)
-        self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
+        assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
+
+
+class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
+    """BitsAndBytes quantization tests for QwenImage Transformer."""
+
+
+class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
+    """TorchAO quantization tests for QwenImage Transformer."""