From ea08148bbd195878a6b44ce1142de8d134ec4e2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Feb 2026 15:41:57 +0530 Subject: [PATCH] recompile limit --- tests/models/testing_utils/compile.py | 3 --- tests/models/unets/test_models_unet_2d_condition.py | 3 +++ 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/testing_utils/compile.py b/tests/models/testing_utils/compile.py index 998b88fb46..4787d0742b 100644 --- a/tests/models/testing_utils/compile.py +++ b/tests/models/testing_utils/compile.py @@ -92,9 +92,6 @@ class TorchCompileTesterMixin: model.eval() model.compile_repeated_blocks(fullgraph=True) - if self.model_class.__name__ == "UNet2DConditionModel": - recompile_limit = 2 - with ( torch._inductor.utils.fresh_inductor_cache(), torch._dynamo.config.patch(recompile_limit=recompile_limit), diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 5fdada7793..a7293208d3 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -1199,6 +1199,9 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin): """Torch compile tests for UNet2DConditionModel.""" + def test_torch_compile_repeated_blocks(self): + return super().test_torch_compile_repeated_blocks(recompile_limit=2) + class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin): """LoRA hot-swapping tests for UNet2DConditionModel."""