Compare commits

...

1 Commits

Author SHA1 Message Date
sayakpaul
d7e73dc237 migrate group offloading tests to pytest 2026-03-09 18:50:37 +05:30

View File

@@ -14,16 +14,15 @@
import contextlib
import gc
import unittest
import logging
import pytest
import torch
from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.hooks import HookRegistry, ModelHook
from diffusers.models import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.utils import get_logger
from diffusers.utils.import_utils import compare_versions
from ..testing_utils import (
@@ -219,20 +218,18 @@ class NestedContainer(torch.nn.Module):
@require_torch_accelerator
class GroupOffloadTests(unittest.TestCase):
class TestGroupOffload:
in_features = 64
hidden_features = 256
out_features = 64
num_layers = 4
def setUp(self):
def setup_method(self):
with torch.no_grad():
self.model = self.get_model()
self.input = torch.randn((4, self.in_features)).to(torch_device)
def tearDown(self):
super().tearDown()
def teardown_method(self):
del self.model
del self.input
gc.collect()
@@ -248,18 +245,20 @@ class GroupOffloadTests(unittest.TestCase):
num_layers=self.num_layers,
)
@pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"],
reason="Test requires a CUDA or XPU device.",
)
def test_offloading_forward_pass(self):
@torch.no_grad()
def run_forward(model):
gc.collect()
backend_empty_cache(torch_device)
backend_reset_peak_memory_stats(torch_device)
self.assertTrue(
all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
assert all(
module._diffusers_hook.get_hook("group_offloading") is not None
for module in model.modules()
if hasattr(module, "_diffusers_hook")
)
model.eval()
output = model(self.input)[0].cpu()
@@ -291,41 +290,37 @@ class GroupOffloadTests(unittest.TestCase):
output_with_group_offloading5, mem5 = run_forward(model)
# Precision assertions - offloading should not impact the output
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5))
self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5))
assert torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)
assert torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)
# Memory assertions - offloading should reduce memory usage
self.assertTrue(mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline)
assert mem4 <= mem5 < mem2 <= mem3 < mem1 < mem_baseline
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self):
def test_warning_logged_if_group_offloaded_module_moved_to_accelerator(self, caplog):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.models.modeling_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
with caplog.at_level(logging.WARNING, logger="diffusers.models.modeling_utils"):
self.model.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self):
def test_warning_logged_if_group_offloaded_pipe_moved_to_accelerator(self, caplog):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
pipe = DummyPipeline(self.model)
self.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
logger = get_logger("diffusers.pipelines.pipeline_utils")
logger.setLevel("INFO")
with self.assertLogs(logger, level="WARNING") as cm:
with caplog.at_level(logging.WARNING, logger="diffusers.pipelines.pipeline_utils"):
pipe.to(torch_device)
self.assertIn(f"The module '{self.model.__class__.__name__}' is group offloaded", cm.output[0])
assert f"The module '{self.model.__class__.__name__}' is group offloaded" in caplog.text
def test_error_raised_if_streams_used_and_no_accelerator_device(self):
torch_accelerator_module = getattr(torch, torch_device, torch.cuda)
original_is_available = torch_accelerator_module.is_available
torch_accelerator_module.is_available = lambda: False
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
self.model.enable_group_offload(
onload_device=torch.device(torch_device), offload_type="leaf_level", use_stream=True
)
@@ -333,31 +328,31 @@ class GroupOffloadTests(unittest.TestCase):
def test_error_raised_if_supports_group_offloading_false(self):
self.model._supports_group_offloading = False
with self.assertRaisesRegex(ValueError, "does not support group offloading"):
with pytest.raises(ValueError, match="does not support group offloading"):
self.model.enable_group_offload(onload_device=torch.device(torch_device))
def test_error_raised_if_model_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
pipe.enable_model_cpu_offload()
def test_error_raised_if_sequential_offloading_applied_on_group_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
with self.assertRaisesRegex(ValueError, "You are trying to apply model/sequential CPU offloading"):
with pytest.raises(ValueError, match="You are trying to apply model/sequential CPU offloading"):
pipe.enable_sequential_cpu_offload()
def test_error_raised_if_group_offloading_applied_on_model_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_model_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
with pytest.raises(ValueError, match="Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_error_raised_if_group_offloading_applied_on_sequential_offloaded_module(self):
pipe = DummyPipeline(self.model)
pipe.enable_sequential_cpu_offload()
with self.assertRaisesRegex(ValueError, "Cannot apply group offloading"):
with pytest.raises(ValueError, match="Cannot apply group offloading"):
pipe.model.enable_group_offload(torch_device, offload_type="block_level", num_blocks_per_group=3)
def test_block_level_stream_with_invocation_order_different_from_initialization_order(self):
@@ -376,12 +371,12 @@ class GroupOffloadTests(unittest.TestCase):
context = contextlib.nullcontext()
if compare_versions("diffusers", "<=", "0.33.0"):
# Will raise a device mismatch RuntimeError mentioning weights are on CPU but input is on device
context = self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device")
context = pytest.raises(RuntimeError, match="Expected all tensors to be on the same device")
with context:
model(self.input)
@parameterized.expand([("block_level",), ("leaf_level",)])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
if torch.device(torch_device).type not in ["cuda", "xpu"]:
return
@@ -407,14 +402,14 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match."
num_repeats = 2
for i in range(num_repeats):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations."
for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
assert ref_name == name
@@ -428,9 +423,7 @@ class GroupOffloadTests(unittest.TestCase):
absdiff = diff.abs()
absmax = absdiff.max().item()
cumulated_absmax += absmax
self.assertLess(
cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
)
assert cumulated_absmax < 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
def test_vae_like_model_without_streams(self):
"""Test VAE-like model with block-level offloading but without streams."""
@@ -452,9 +445,7 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
)
assert torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match for VAE-like model without streams."
def test_model_with_only_standalone_layers(self):
"""Test that models with only standalone layers (no ModuleList/Sequential) work with block-level offloading."""
@@ -475,12 +466,11 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for model with standalone layers.",
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match at iteration {i} for model with standalone layers."
)
@parameterized.expand([("block_level",), ("leaf_level",)])
@pytest.mark.parametrize("offload_type", ["block_level", "leaf_level"])
def test_standalone_conv_layers_with_both_offload_types(self, offload_type: str):
"""Test that standalone Conv2d layers work correctly with both block-level and leaf-level offloading."""
if torch.device(torch_device).type not in ["cuda", "xpu"]:
@@ -501,9 +491,8 @@ class GroupOffloadTests(unittest.TestCase):
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match for standalone Conv layers with {offload_type}.",
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match for standalone Conv layers with {offload_type}."
)
def test_multiple_invocations_with_vae_like_model(self):
@@ -526,7 +515,7 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2):
out_ref = model_ref(x).sample
out = model(x).sample
self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}.")
assert torch.allclose(out_ref, out, atol=1e-5), f"Outputs do not match at iteration {i}."
def test_nested_container_parameters_offloading(self):
"""Test that parameters from non-computational layers in nested containers are handled correctly."""
@@ -547,9 +536,8 @@ class GroupOffloadTests(unittest.TestCase):
for i in range(2):
out_ref = model_ref(x)
out = model(x)
self.assertTrue(
torch.allclose(out_ref, out, atol=1e-5),
f"Outputs do not match at iteration {i} for nested parameters.",
assert torch.allclose(out_ref, out, atol=1e-5), (
f"Outputs do not match at iteration {i} for nested parameters."
)
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@@ -602,7 +590,7 @@ class DummyModelWithConditionalModules(ModelMixin):
return x
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
class TestConditionalModuleGroupOffload(TestGroupOffload):
"""Tests for conditionally-executed modules under group offloading with streams.
Regression tests for the case where a module is not executed during the first forward pass
@@ -620,10 +608,10 @@ class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
num_layers=self.num_layers,
)
@parameterized.expand([("leaf_level",), ("block_level",)])
@unittest.skipIf(
@pytest.mark.parametrize("offload_type", ["leaf_level", "block_level"])
@pytest.mark.skipif(
torch.device(torch_device).type not in ["cuda", "xpu"],
"Test requires a CUDA or XPU device.",
reason="Test requires a CUDA or XPU device.",
)
def test_conditional_modules_with_stream(self, offload_type: str):
"""Regression test: conditionally-executed modules must not cause device mismatch when using streams.
@@ -670,23 +658,20 @@ class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
# execution order is traced. optional_proj_1/2 are NOT in the traced order.
out_ref_no_opt = model_ref(x, optional_input=None)
out_no_opt = model(x, optional_input=None)
self.assertTrue(
torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
assert torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5), (
f"[{offload_type}] Outputs do not match on first pass (no optional_input)."
)
# Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
out_ref_with_opt = model_ref(x, optional_input=optional_input)
out_with_opt = model(x, optional_input=optional_input)
self.assertTrue(
torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
assert torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5), (
f"[{offload_type}] Outputs do not match on second pass (with optional_input)."
)
# Third pass again without optional_input — verify stable behavior.
out_ref_no_opt2 = model_ref(x, optional_input=None)
out_no_opt2 = model(x, optional_input=None)
self.assertTrue(
torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
assert torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5), (
f"[{offload_type}] Outputs do not match on third pass (back to no optional_input)."
)