Compare commits

...

4 Commits

Author SHA1 Message Date
DN6
6c75674f8c update 2026-03-26 16:13:00 +05:30
DN6
598daed3c8 update 2026-03-26 10:48:07 +05:30
DN6
dede05dccf update 2026-03-26 10:45:03 +05:30
DN6
ed1ad8e3c2 update 2026-03-26 10:37:18 +05:30
4 changed files with 409 additions and 375 deletions

View File

@@ -12,71 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideo15Transformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideo15Transformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.99, 0.99, 0.99]
class HunyuanVideo15TransformerTesterConfig(BaseModelTesterConfig):
text_embed_dim = 16
text_embed_2_dim = 8
image_embed_dim = 12
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 8
width = 8
sequence_length = 6
sequence_length_2 = 4
image_sequence_length = 3
def model_class(self):
return HunyuanVideo15Transformer3DModel
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, self.text_embed_dim), device=torch_device)
encoder_hidden_states_2 = torch.randn(
(batch_size, sequence_length_2, self.text_embed_2_dim), device=torch_device
)
encoder_attention_mask = torch.ones((batch_size, sequence_length), device=torch_device)
encoder_attention_mask_2 = torch.ones((batch_size, sequence_length_2), device=torch_device)
# All zeros for inducing T2V path in the model.
image_embeds = torch.zeros((batch_size, image_sequence_length, self.image_embed_dim), device=torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def model_split_percents(self) -> list:
return [0.99, 0.99, 0.99]
@property
def output_shape(self) -> tuple:
return (4, 1, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 1, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"encoder_hidden_states_2": encoder_hidden_states_2,
"encoder_attention_mask_2": encoder_attention_mask_2,
"image_embeds": image_embeds,
}
@property
def input_shape(self):
return (4, 1, 8, 8)
@property
def output_shape(self):
return (4, 1, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -93,9 +75,40 @@ class HunyuanVideo15Transformer3DTests(ModelTesterMixin, unittest.TestCase):
"target_size": 16,
"task_type": "t2v",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 8
width = 8
sequence_length = 6
sequence_length_2 = 4
image_sequence_length = 3
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, self.text_embed_dim), generator=self.generator, device=torch_device
),
"encoder_hidden_states_2": randn_tensor(
(batch_size, sequence_length_2, self.text_embed_2_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length), device=torch_device),
"encoder_attention_mask_2": torch.ones((batch_size, sequence_length_2), device=torch_device),
"image_embeds": torch.zeros(
(batch_size, image_sequence_length, self.image_embed_dim), device=torch_device
),
}
class TestHunyuanVideo15Transformer(HunyuanVideo15TransformerTesterConfig, ModelTesterMixin):
pass
class TestHunyuanVideo15TransformerTraining(HunyuanVideo15TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideo15Transformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -13,75 +13,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanDiT2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanDiT2DModel
main_input_name = "hidden_states"
class HunyuanDiTTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanDiT2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = 8
embedding_dim = 8
sequence_length = 4
sequence_length_t5 = 4
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device)
encoder_hidden_states_t5 = torch.randn((batch_size, sequence_length_t5, embedding_dim)).to(torch_device)
text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,), dtype=encoder_hidden_states.dtype).to(torch_device)
original_size = [1024, 1024]
target_size = [16, 16]
crops_coords_top_left = [0, 0]
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids, add_time_ids], dtype=encoder_hidden_states.dtype).to(torch_device)
style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
image_rotary_emb = [
torch.ones(size=(1, 8), dtype=encoder_hidden_states.dtype),
torch.zeros(size=(1, 8), dtype=encoder_hidden_states.dtype),
]
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"text_embedding_mask": text_embedding_mask,
"encoder_hidden_states_t5": encoder_hidden_states_t5,
"text_embedding_mask_t5": text_embedding_mask_t5,
"timestep": timestep,
"image_meta_size": add_time_ids,
"style": style,
"image_rotary_emb": image_rotary_emb,
}
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-hunyuan-dit-pipe"
@property
def input_shape(self):
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (8, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 8, 8)
@property
def output_shape(self):
return (8, 8, 8)
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"sample_size": 8,
"patch_size": 2,
"in_channels": 4,
@@ -96,18 +77,70 @@ class HunyuanDiTTests(ModelTesterMixin, unittest.TestCase):
"text_len_t5": 4,
"activation_fn": "gelu-approximate",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(
expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = width = 8
embedding_dim = 8
sequence_length = 4
sequence_length_t5 = 4
hidden_states = randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
text_embedding_mask = torch.ones(size=(batch_size, sequence_length)).to(torch_device)
encoder_hidden_states_t5 = randn_tensor(
(batch_size, sequence_length_t5, embedding_dim), generator=self.generator, device=torch_device
)
text_embedding_mask_t5 = torch.ones(size=(batch_size, sequence_length_t5)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).float().to(torch_device)
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
def test_set_xformers_attn_processor_for_determinism(self):
pass
original_size = [1024, 1024]
target_size = [16, 16]
crops_coords_top_left = [0, 0]
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids] * batch_size, dtype=torch.float32).to(torch_device)
style = torch.zeros(size=(batch_size,), dtype=int).to(torch_device)
image_rotary_emb = [
torch.ones(size=(1, 8), dtype=torch.float32),
torch.zeros(size=(1, 8), dtype=torch.float32),
]
@unittest.skip("HunyuanDIT use a custom processor HunyuanAttnProcessor2_0")
def test_set_attn_processor_for_determinism(self):
pass
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"text_embedding_mask": text_embedding_mask,
"encoder_hidden_states_t5": encoder_hidden_states_t5,
"text_embedding_mask_t5": text_embedding_mask_t5,
"timestep": timestep,
"image_meta_size": add_time_ids,
"style": style,
"image_rotary_emb": image_rotary_emb,
}
class TestHunyuanDiT(HunyuanDiTTesterConfig, ModelTesterMixin):
def test_output(self):
batch_size = self.get_dummy_inputs()[self.main_input_name].shape[0]
super().test_output(expected_output_shape=(batch_size,) + self.output_shape)
class TestHunyuanDiTTraining(HunyuanDiTTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanDiT2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestHunyuanDiTCompile(HunyuanDiTTesterConfig, TorchCompileTesterMixin):
pass
class TestHunyuanDiTBitsAndBytes(HunyuanDiTTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for HunyuanDiT."""
class TestHunyuanDiTTorchAo(HunyuanDiTTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for HunyuanDiT."""

View File

@@ -12,64 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideoTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
# ======================== HunyuanVideo Text-to-Video ========================
class HunyuanVideoTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-random-hunyuanvideo"
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -85,30 +80,9 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 8
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
@@ -116,105 +90,74 @@ class HunyuanSkyreelsImageToVideoTransformer3DTests(ModelTesterMixin, unittest.T
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
torch_device, dtype=torch.float32
),
}
@property
def input_shape(self):
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
class TestHunyuanVideoTransformer(HunyuanVideoTransformerTesterConfig, ModelTesterMixin):
pass
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 8,
"out_channels": 4,
"num_attention_heads": 2,
"attention_head_dim": 10,
"num_layers": 1,
"num_single_layers": 1,
"num_refiner_layers": 1,
"patch_size": 1,
"patch_size_t": 1,
"guidance_embeds": True,
"text_embed_dim": 16,
"pooled_projection_dim": 8,
"rope_axes_dim": (2, 4, 4),
"image_condition_type": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
class TestHunyuanVideoTransformerTraining(HunyuanVideoTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanSkyreelsImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanSkyreelsImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestHunyuanVideoTransformerCompile(HunyuanVideoTransformerTesterConfig, TorchCompileTesterMixin):
pass
class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class TestHunyuanVideoTransformerBitsAndBytes(HunyuanVideoTransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for HunyuanVideo Transformer."""
class TestHunyuanVideoTransformerTorchAo(HunyuanVideoTransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for HunyuanVideo Transformer."""
# ======================== HunyuanVideo Image-to-Video (Latent Concat) ========================
class HunyuanVideoI2VTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 2 * 4 + 1
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
}
def main_input_name(self) -> str:
return "hidden_states"
@property
def input_shape(self):
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"in_channels": 2 * 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
@@ -230,33 +173,9 @@ class HunyuanVideoImageToVideoTransformer3DTests(ModelTesterMixin, unittest.Test
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "latent_concat",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanImageToVideoCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
num_channels = 2
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 2 * 4 + 1
num_frames = 1
height = 16
width = 16
@@ -264,32 +183,58 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u
pooled_projection_dim = 8
sequence_length = 12
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device, dtype=torch.float32)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
}
class TestHunyuanVideoI2VTransformer(HunyuanVideoI2VTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
class TestHunyuanVideoI2VTransformerCompile(HunyuanVideoI2VTransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== HunyuanVideo Token Replace Image-to-Video ========================
class HunyuanVideoTokenReplaceTransformerTesterConfig(BaseModelTesterConfig):
@property
def input_shape(self):
def model_class(self):
return HunyuanVideoTransformer3DModel
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (4, 1, 16, 16)
@property
def input_shape(self) -> tuple:
return (8, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"in_channels": 2,
"out_channels": 4,
"num_attention_heads": 2,
@@ -305,19 +250,42 @@ class HunyuanVideoTokenReplaceImageToVideoTransformer3DTests(ModelTesterMixin, u
"rope_axes_dim": (2, 4, 4),
"image_condition_type": "token_replace",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 2
num_frames = 1
height = 16
width = 16
text_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(
torch_device, dtype=torch.float32
),
}
class TestHunyuanVideoTokenReplaceTransformer(HunyuanVideoTokenReplaceTransformerTesterConfig, ModelTesterMixin):
def test_output(self):
super().test_output(expected_output_shape=(1, *self.output_shape))
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class HunyuanVideoTokenReplaceCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = HunyuanVideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return HunyuanVideoTokenReplaceImageToVideoTransformer3DTests().prepare_init_args_and_inputs_for_common()
class TestHunyuanVideoTokenReplaceTransformerCompile(
HunyuanVideoTokenReplaceTransformerTesterConfig, TorchCompileTesterMixin
):
pass

View File

@@ -12,84 +12,49 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import HunyuanVideoFramepackTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
model_class = HunyuanVideoFramepackTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.5, 0.7, 0.9]
class HunyuanVideoFramepackTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return HunyuanVideoFramepackTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 3
height = 4
width = 4
text_encoder_embedding_dim = 16
image_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_encoder_embedding_dim)).to(torch_device)
pooled_projections = torch.randn((batch_size, pooled_projection_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
image_embeds = torch.randn((batch_size, sequence_length, image_encoder_embedding_dim)).to(torch_device)
indices_latents = torch.ones((3,)).to(torch_device)
latents_clean = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_clean = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_2x = torch.randn((batch_size, num_channels, num_frames - 1, height, width)).to(torch_device)
indices_latents_history_2x = torch.ones((num_frames - 1,)).to(torch_device)
latents_history_4x = torch.randn((batch_size, num_channels, (num_frames - 1) * 4, height, width)).to(
torch_device
)
indices_latents_history_4x = torch.ones(((num_frames - 1) * 4,)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
guidance = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def model_split_percents(self) -> list:
return [0.5, 0.7, 0.9]
@property
def output_shape(self) -> tuple:
return (4, 3, 4, 4)
@property
def input_shape(self) -> tuple:
return (4, 3, 4, 4)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_projections,
"encoder_attention_mask": encoder_attention_mask,
"guidance": guidance,
"image_embeds": image_embeds,
"indices_latents": indices_latents,
"latents_clean": latents_clean,
"indices_latents_clean": indices_latents_clean,
"latents_history_2x": latents_history_2x,
"indices_latents_history_2x": indices_latents_history_2x,
"latents_history_4x": latents_history_4x,
"indices_latents_history_4x": indices_latents_history_4x,
}
@property
def input_shape(self):
return (4, 3, 4, 4)
@property
def output_shape(self):
return (4, 3, 4, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -108,9 +73,64 @@ class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase):
"image_proj_dim": 16,
"has_clean_x_embedder": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 3
height = 4
width = 4
text_encoder_embedding_dim = 16
image_encoder_embedding_dim = 16
pooled_projection_dim = 8
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"pooled_projections": randn_tensor(
(batch_size, pooled_projection_dim), generator=self.generator, device=torch_device
),
"encoder_attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"guidance": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"image_embeds": randn_tensor(
(batch_size, sequence_length, image_encoder_embedding_dim),
generator=self.generator,
device=torch_device,
),
"indices_latents": torch.ones((num_frames,)).to(torch_device),
"latents_clean": randn_tensor(
(batch_size, num_channels, num_frames - 1, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_clean": torch.ones((num_frames - 1,)).to(torch_device),
"latents_history_2x": randn_tensor(
(batch_size, num_channels, num_frames - 1, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_history_2x": torch.ones((num_frames - 1,)).to(torch_device),
"latents_history_4x": randn_tensor(
(batch_size, num_channels, (num_frames - 1) * 4, height, width),
generator=self.generator,
device=torch_device,
),
"indices_latents_history_4x": torch.ones(((num_frames - 1) * 4,)).to(torch_device),
}
class TestHunyuanVideoFramepackTransformer(HunyuanVideoFramepackTransformerTesterConfig, ModelTesterMixin):
pass
class TestHunyuanVideoFramepackTransformerTraining(HunyuanVideoFramepackTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"HunyuanVideoFramepackTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)