mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-20 18:00:46 +08:00
recompile limit
This commit is contained in:
@@ -92,9 +92,6 @@ class TorchCompileTesterMixin:
|
||||
model.eval()
|
||||
model.compile_repeated_blocks(fullgraph=True)
|
||||
|
||||
if self.model_class.__name__ == "UNet2DConditionModel":
|
||||
recompile_limit = 2
|
||||
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(recompile_limit=recompile_limit),
|
||||
|
||||
@@ -1199,6 +1199,9 @@ class TestUNet2DConditionIPAdapter(UNet2DConditionTesterConfig, IPAdapterTesterM
|
||||
class TestUNet2DConditionModelCompile(UNet2DConditionTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for UNet2DConditionModel."""
|
||||
|
||||
def test_torch_compile_repeated_blocks(self):
|
||||
return super().test_torch_compile_repeated_blocks(recompile_limit=2)
|
||||
|
||||
|
||||
class TestUNet2DConditionModelLoRAHotSwap(UNet2DConditionTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for UNet2DConditionModel."""
|
||||
|
||||
Reference in New Issue
Block a user