Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
d7e73dc237 migrate group offloading tests to pytest 2026-03-09 18:50:37 +05:30

View File

@@ -14,16 +14,15 @@
import contextlib import contextlib
import gc import gc
import unittest import logging
import pytest
import torch import torch
from parameterized import parameterized
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.import_utils import compare_versions from diffusers.utils.import_utils import compare_versions
from ..testing_utils import ( from ..testing_utils import (
@@ -219,20 +218,18 @@ class NestedContainer(torch.nn.Module):
@require_torch_accelerator @require_torch_accelerator
class GroupOffloadTests(unittest.TestCase): class TestGroupOffload:
in_features = 64 in_features = 64
hidden_features = 256 hidden_features = 256
out_features = 64 out_features = 64
num_layers = 4 num_layers = 4
def setUp(self): def setup_method(self):
with torch.no_grad(): with torch.no_grad():
self.model = self.get_model() self.model = self.get_model()
self.input = torch.randn((4, self.in_features)).to(torch_device) self.input = torch.randn((4, self.in_features)).to(torch_device)
def tearDown(self): def teardown_method(self):
super().tearDown()
del self.model del self.model
del self.input del self.input
gc.collect() gc.collect()
@@ -248,18 +245,20 @@ class GroupOffloadTests(unittest.TestCase):
num_layers=self.num_layers, num_layers=self.num_layers,
) )
@pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"],
reason="Test requires a CUDA or XPU device.",
)
def test_offloading_forward_pass(self): def test_offloading_forward_pass(self):
@torch.no_grad() @torch.no_grad()
def run_forward(model): def run_forward(model):
gc.collect() gc.collect()
backend_empty_cache(torch_device) backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device) backend_reset_peak_memory_stats(torch_device)
self.assertTrue( assert all(
all( module._diffusers_hook.get_hook("group_offloading") is not None
module._diffusers_hook.get_hook("group_offloading") is not None for module in model.modules()
for module in model.modules() if hasattr(module, "_diffusers_hook")
if hasattr(module, "_diffusers_hook")
)
) )
model.eval() model.eval()
output = model(self.input)[0].cpu() output = model(self.input)[0].cpu()
@@ -291,41 +290,37 @@ class GroupOffloadTests(unittest.TestCase):
output_with_group_offloading5, mem5 = run_forward(model) output_with_group_offloading5, mem5 = run_forward(model)
# Precision assertions - offloading should not impact the output # Precision assertions - offloading should not impact the output
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)) assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
# Memory assertions - offloading should reduce memory usage # Memory assertions - offloading should reduce memory usage
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline) assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self): def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
if torch.device(torch_device).type not in ["cuda", "xpu"]: if torch.device(torch_device).type not in ["cuda", "xpu"]:
return return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.models.modeling_utils") with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
self.model.to(torch_device) self.model.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self): def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
if torch.device(torch_device).type not in ["cuda", "xpu"]: if torch.device(torch_device).type not in ["cuda", "xpu"]:
return return
pipe = DummyPipeline(self.model) pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.pipelines.pipeline_utils") with caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"):
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
pipe.to(torch_device) pipe.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0]) assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
def test_error_raised_if_streams_used_and_no_accelerator_device(self): def test_error_raised_if_streams_used_and_no_accelerator_device(self):
torch_accelerator_module = getattr(torch, torch_device, torch.cuda) torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
original_is_available = torch_accelerator_module.is_available original_is_available = torch_accelerator_module.is_available
torch_accelerator_module.is_available = lambda: False torch_accelerator_module.is_available = lambda: False
with self.assertRaises(ValueError): with pytest.raises(ValueError):
self.model.enable_group_offload( self.model.enable_group_offload(
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
) )
@@ -333,31 +328,31 @@ class GroupOffloadTests(unittest.TestCase):
def test_error_raised_if_supports_group_offloading_false(self): def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"): with pytest.raises(ValueError, match="does not support group offloading"):
self.model.enable_group_offload(onload_device=torch.device(torch_device)) self.model.enable_group_offload(onload_device=torch.device(torch_device))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self): def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model) pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self): def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model) pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"): with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self): def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
pipe = DummyPipeline(self.model) pipe = DummyPipeline(self.model)
pipe.enable_model_cpu_offload() pipe.enable_model_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): with pytest.raises(ValueError, match="Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self): def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
pipe = DummyPipeline(self.model) pipe = DummyPipeline(self.model)
pipe.enable_sequential_cpu_offload() pipe.enable_sequential_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"): with pytest.raises(ValueError, match="Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3) pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self): def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
@@ -376,12 +371,12 @@ class GroupOffloadTests(unittest.TestCase):
context = contextlib.nullcontext() context = contextlib.nullcontext()
if compare_versions("diffusers", "<=", "0.33.0"): if compare_versions("diffusers", "<=", "0.33.0"):
# Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device # Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device") context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
with context: with context:
model(self.input) model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)]) @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str): def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]: if torch.device(torch_device).type not in ["cuda", "xpu"]:
return return
@@ -407,14 +402,14 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x) out_ref = model_ref(x)
out = model(x) out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.") assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
num_repeats = 2 num_repeats = 2
for i in range(num_repeats): for i in range(num_repeats):
out_ref = model_ref(x) out_ref = model_ref(x)
out = model(x) out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.") assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()): for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name assert ref_name == name
@@ -428,9 +423,7 @@ class GroupOffloadTests(unittest.TestCase):
absdiff = diff.abs() absdiff = diff.abs()
absmax = absdiff.max().item() absmax = absdiff.max().item()
cumulated_absmax += absmax cumulated_absmax += absmax
self.assertLess( assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
def test_vae_like_model_without_streams(self): def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams.""" """Test VAE-like model with block-level offloading but without streams."""
@@ -452,9 +445,7 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x).sample out_ref = model_ref(x).sample
out = model(x).sample out = model(x).sample
self.assertTrue( assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)
def test_model_with_only_standalone_layers(self): def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading.""" """Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -475,12 +466,11 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2): for i in range(2):
out_ref = model_ref(x) out_ref = model_ref(x)
out = model(x) out = model(x)
self.assertTrue( assert torch.allclose(out_ref, out, atol=1e-5), (
torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i} for model with standalone layers."
f"Outputs do not match at iteration {i} for model with standalone layers.",
) )
@parameterized.expand([("block_level",), ("leaf_level",)]) @pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str): def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading.""" """Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]: if torch.device(torch_device).type not in ["cuda", "xpu"]:
@@ -501,9 +491,8 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x).sample out_ref = model_ref(x).sample
out = model(x).sample out = model(x).sample
self.assertTrue( assert torch.allclose(out_ref, out, atol=1e-5), (
torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match for standalone Conv layers with {offload_type}."
f"Outputs do not match for standalone Conv layers with {offload_type}.",
) )
def test_multiple_invocations_with_vae_like_model(self): def test_multiple_invocations_with_vae_like_model(self):
@@ -526,7 +515,7 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2): for i in range(2):
out_ref = model_ref(x).sample out_ref = model_ref(x).sample
out = model(x).sample out = model(x).sample
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.") assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
def test_nested_container_parameters_offloading(self): def test_nested_container_parameters_offloading(self):
"""Test that parameters from non-computational layers in nested containers are handled correctly.""" """Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -547,9 +536,8 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2): for i in range(2):
out_ref = model_ref(x) out_ref = model_ref(x)
out = model(x) out = model(x)
self.assertTrue( assert torch.allclose(out_ref, out, atol=1e-5), (
torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i} for nested parameters."
f"Outputs do not match at iteration {i} for nested parameters.",
) )
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None): def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@@ -602,7 +590,7 @@ class DummyModelWithConditionalModules(ModelMixin):
return x return x
class ConditionalModuleGroupOffloadTests(GroupOffloadTests): class TestConditionalModuleGroupOffload(TestGroupOffload):
"""Tests for conditionally-executed modules under group offloading with streams. """Tests for conditionally-executed modules under group offloading with streams.
Regression tests for the case where a module is not executed during the first forward pass Regression tests for the case where a module is not executed during the first forward pass
@@ -620,10 +608,10 @@ class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
num_layers=self.num_layers, num_layers=self.num_layers,
) )
@parameterized.expand([("leaf_level",), ("block_level",)]) @pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
@unittest.skipIf( @pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"], torch.device(torch_device).type not in ["cuda", "xpu"],
"Test requires a CUDA or XPU device.", reason="Test requires a CUDA or XPU device.",
) )
def test_conditional_modules_with_stream(self, offload_type: str): def test_conditional_modules_with_stream(self, offload_type: str):
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams. """Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -670,23 +658,20 @@ class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
# execution order is traced. optional_proj_1/2 are NOT in the traced order. # execution order is traced. optional_proj_1/2 are NOT in the traced order.
out_ref_no_opt = model_ref(x, optional_input=None) out_ref_no_opt = model_ref(x, optional_input=None)
out_no_opt = model(x, optional_input=None) out_no_opt = model(x, optional_input=None)
self.assertTrue( assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
) )
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called. # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
out_ref_with_opt = model_ref(x, optional_input=optional_input) out_ref_with_opt = model_ref(x, optional_input=optional_input)
out_with_opt = model(x, optional_input=optional_input) out_with_opt = model(x, optional_input=optional_input)
self.assertTrue( assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
) )
# Third pass again without optional_input — verify stable behavior. # Third pass again without optional_input — verify stable behavior.
out_ref_no_opt2 = model_ref(x, optional_input=None) out_ref_no_opt2 = model_ref(x, optional_input=None)
out_no_opt2 = model(x, optional_input=None) out_no_opt2 = model(x, optional_input=None)
self.assertTrue( assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
) )