This commit is contained in:
sayakpaul
2026-02-16 13:10:04 +05:30
parent 0e42a3ff93
commit 2b67fb65ef

View File

@@ -765,11 +765,7 @@ class TestUNet2DConditionTraining(UNet2DConditionTesterConfig, TrainingTesterMix
"Transformer2DModel",
"DownBlock2D",
}
attention_head_dim = (8, 16)
block_out_channels = (16, 32)
super().test_gradient_checkpointing_is_applied(
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
)
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestUNet2DConditionAttention(UNet2DConditionTesterConfig, AttentionTesterMixin):
@@ -988,7 +984,7 @@ class TestUNet2DConditionCustomDiffusion(UNet2DConditionTesterConfig):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
assert os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
@@ -1038,7 +1034,7 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
@property
def ip_adapter_processor_cls(self):
return IPAdapterAttnProcessor
return (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)
def create_ip_adapter_state_dict(self, model):
return create_ip_adapter_state_dict(model)