mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-27 02:47:41 +08:00
Compare commits
2 Commits
glmimage-r
...
cog-tests
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f7405f2b44 | ||
|
|
b757035df6 |
1
.github/workflows/claude_review.yml
vendored
1
.github/workflows/claude_review.yml
vendored
@@ -10,6 +10,7 @@ permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
issues: read
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
claude-review:
|
||||
|
||||
@@ -13,59 +13,53 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CogVideoXTransformer3DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CogVideoXTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.7, 0.7, 0.8]
|
||||
# ======================== CogVideoX ========================
|
||||
|
||||
|
||||
class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CogVideoXTransformer3DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
num_frames = 1
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.7, 0.8]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"in_channels": 4,
|
||||
@@ -81,50 +75,66 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"temporal_compression_ratio": 4,
|
||||
"max_text_seq_length": 8,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CogVideoXTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CogVideoXTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
num_frames = 1
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CogVideoXTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
# ======================== CogVideoX 1.5 ========================
|
||||
|
||||
|
||||
class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def input_shape(self):
|
||||
def model_class(self):
|
||||
return CogVideoXTransformer3DModel
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"in_channels": 4,
|
||||
@@ -141,9 +151,29 @@ class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"max_text_seq_length": 8,
|
||||
"use_rotary_positional_embeddings": True,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CogVideoXTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
@@ -13,63 +13,50 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CogView3PlusTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
torch_device,
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CogView3PlusTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CogView3PlusTransformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"original_size": original_size,
|
||||
"target_size": target_size,
|
||||
"crop_coords": crop_coords,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (1, 4, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_layers": 2,
|
||||
@@ -82,9 +69,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"pos_embed_max_size": 8,
|
||||
"sample_size": 8,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCogView3PlusTransformer(CogView3PlusTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestCogView3PlusTransformerTraining(CogView3PlusTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CogView3PlusTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestCogView3PlusTransformerCompile(CogView3PlusTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
@@ -12,59 +12,46 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import CogView4Transformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = CogView4Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
class CogView4TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return CogView4Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"original_size": original_size,
|
||||
"target_size": target_size,
|
||||
"crop_coords": crop_coords,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 8, 8)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"num_layers": 2,
|
||||
@@ -75,9 +62,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"time_embed_dim": 8,
|
||||
"condition_dim": 4,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = 8
|
||||
width = 8
|
||||
embedding_dim = 8
|
||||
sequence_length = 8
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestCogView4Transformer(CogView4TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestCogView4TransformerTraining(CogView4TransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"CogView4Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestCogView4TransformerCompile(CogView4TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import GlmImageTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class GlmImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return GlmImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 8, 8)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (4, 8, 8)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"num_layers": 1,
|
||||
"attention_head_dim": 8,
|
||||
"num_attention_heads": 2,
|
||||
"text_embed_dim": 32,
|
||||
"time_embed_dim": 16,
|
||||
"condition_dim": 8,
|
||||
"prior_vq_quantizer_codebook_size": 64,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_channels = 4
|
||||
height = width = 8
|
||||
sequence_length = 12
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, 32), generator=self.generator, device=torch_device
|
||||
),
|
||||
"prior_token_id": torch.randint(0, 64, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"prior_token_drop": torch.zeros(batch_size, dtype=torch.bool, device=torch_device),
|
||||
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
|
||||
"target_size": torch.tensor([[height, width]] * batch_size, dtype=torch.float32).to(torch_device),
|
||||
"crop_coords": torch.tensor([[0, 0]] * batch_size, dtype=torch.float32).to(torch_device),
|
||||
}
|
||||
|
||||
|
||||
class TestGlmImageTransformer(GlmImageTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestGlmImageTransformerTraining(GlmImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"GlmImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestGlmImageTransformerCompile(GlmImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
Reference in New Issue
Block a user