Compare commits

..

4 Commits

Author       SHA1        Message                                            Date
Sayak Paul   1e6578bbe3  Merge branch 'main' into refactor-caching-tests    2026-03-10 09:25:51 +05:30
Sayak Paul   81aa43271b  Merge branch 'main' into refactor-caching-tests    2026-03-10 08:57:11 +05:30
sayakpaul    9239908f5d  include taylorseer in the caching mixin.           2026-03-10 08:56:42 +05:30
sayakpaul    9cd3e6ba88  refactor magcache tests.                           2026-03-09 19:26:42 +05:30
10 changed files with 225 additions and 397 deletions

View File

@@ -60,16 +60,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -78,7 +68,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable "Ulysses Anything" attention to support arbitrary
# sequence lengths and head counts.
ulysses_anything: bool = False
@@ -135,7 +124,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
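Note: for orientation, a minimal sketch of how the remaining ContextParallelConfig surface is used, mirroring the _context_parallel_worker kept further down in this diff. The top-level imports, checkpoint, and parallel degrees are illustrative assumptions, not part of this change.

import torch
import torch.distributed as dist

from diffusers import ContextParallelConfig, FluxTransformer2DModel  # top-level exports assumed

# One process per GPU, e.g. launched with torchrun.
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(f"cuda:{rank}")

# ring_degree * ulysses_degree must not exceed the world size (validated in __post_init__ above).
cp_config = ContextParallelConfig(ulysses_degree=dist.get_world_size())
model.enable_parallelism(config=cp_config)
# Subsequent forward passes, e.g. model(**inputs, return_dict=False), now shard the
# sequence dimension across ranks according to the model's _cp_plan.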

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -14,7 +14,6 @@
import importlib
import inspect
import os
import shutil
import sys
import traceback
import warnings
@@ -1884,36 +1883,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
"""Save custom code files (blocks config and Python modules) to the save directory."""
if self._blocks is None:
return
blocks_module = type(self._blocks).__module__
is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
if not is_custom_code:
return
os.makedirs(save_directory, exist_ok=True)
self._blocks.save_pretrained(save_directory)
source_file = inspect.getfile(type(self._blocks))
module_file = os.path.basename(source_file)
dest_file = os.path.join(save_directory, module_file)
if os.path.abspath(source_file) != os.path.abspath(dest_file):
shutil.copyfile(source_file, dest_file)
from ..utils.dynamic_modules_utils import get_relative_import_files
for rel_file in get_relative_import_files(source_file):
rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
rel_dest = os.path.join(save_directory, rel_name)
if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
shutil.copyfile(rel_file, rel_dest)
def save_pretrained(
self,
save_directory: str | os.PathLike,
@@ -2029,8 +1998,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self._maybe_save_custom_code(save_directory)
self.save_config(save_directory=save_directory)
if push_to_hub:

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
from packaging import version
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -844,8 +844,6 @@ class QuantoConfig(QuantizationConfigMixin):
modules_to_not_convert: list[str] | None = None,
**kwargs,
):
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
deprecate("QuantoConfig", "1.0.0", deprecation_message)
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert
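Note: for context on what this config controls, a hedged usage sketch; the checkpoint and dtype choices are illustrative.

import torch

from diffusers import FluxTransformer2DModel, QuantoConfig

# weights_dtype selects the quanto weight type; "float8" is one illustrative choice.
quant_config = QuantoConfig(weights_dtype="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)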

View File

@@ -3,7 +3,6 @@ from typing import TYPE_CHECKING, Any
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
deprecate,
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
@@ -43,9 +42,6 @@ class QuantoQuantizer(DiffusersQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -1,244 +0,0 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from diffusers import MagCacheConfig, apply_mag_cache
from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry
from diffusers.models import ModelMixin
from diffusers.utils import logging
logger = logging.get_logger(__name__)
class DummyBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Output is double the input, so Residual = 2*Input - Input = Input
return hidden_states * 2.0
class DummyTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([DummyBlock(), DummyBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
return hidden_states
class TupleOutputBlock(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
# Returns a tuple
return hidden_states * 2.0, encoder_hidden_states
class TupleTransformer(ModelMixin):
def __init__(self):
super().__init__()
self.transformer_blocks = torch.nn.ModuleList([TupleOutputBlock()])
def forward(self, hidden_states, encoder_hidden_states=None):
for block in self.transformer_blocks:
# Emulate Flux-like behavior
output = block(hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = output[0]
encoder_hidden_states = output[1]
return hidden_states, encoder_hidden_states
class MagCacheTests(unittest.TestCase):
def setUp(self):
# Register standard dummy block
TransformerBlockRegistry.register(
DummyBlock,
TransformerBlockMetadata(return_hidden_states_index=None, return_encoder_hidden_states_index=None),
)
# Register tuple block (Flux style)
TransformerBlockRegistry.register(
TupleOutputBlock,
TransformerBlockMetadata(return_hidden_states_index=0, return_encoder_hidden_states_index=1),
)
def _set_context(self, model, context_name):
"""Helper to set context on all hooks in the model."""
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(context_name)
def _get_calibration_data(self, model):
for module in model.modules():
if hasattr(module, "_diffusers_hook"):
hook = module._diffusers_hook.get_hook("mag_cache_block_hook")
if hook:
return hook.state_manager.get_state().calibration_ratios
return []
def test_mag_cache_validation(self):
"""Test that missing mag_ratios raises ValueError."""
with self.assertRaises(ValueError):
MagCacheConfig(num_inference_steps=10, calibrate=False)
def test_mag_cache_skipping_logic(self):
"""
Tests that MagCache correctly calculates residuals and skips blocks when conditions are met.
"""
model = DummyTransformer()
# Dummy ratios: [1.0, 1.0] implies 0 accumulated error if we skip
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=0.0, # Enable immediate skipping
max_skip_steps=5,
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Input 10.0 -> Output 40.0 (2 blocks * 2x each)
# HeadInput=10. Output=40. Residual=30.
input_t0 = torch.tensor([[[10.0]]])
output_t0 = model(input_t0)
self.assertTrue(torch.allclose(output_t0, torch.tensor([[[40.0]]])), "Step 0 failed")
# Step 1: Input 11.0.
# If Skipped: Output = Input(11) + Residual(30) = 41.0
# If Computed: Output = 11 * 4 = 44.0
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[41.0]]])), f"Expected Skip (41.0), got {output_t1.item()}"
)
def test_mag_cache_retention(self):
"""Test that retention_ratio prevents skipping even if error is low."""
model = DummyTransformer()
# Ratios that imply 0 error, so it *would* skip if retention allowed it
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(
threshold=100.0,
num_inference_steps=2,
retention_ratio=1.0, # Force retention for ALL steps
mag_ratios=ratios,
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
model(torch.tensor([[[10.0]]]))
# Step 1: Should COMPUTE (44.0) not SKIP (41.0) because of retention
input_t1 = torch.tensor([[[11.0]]])
output_t1 = model(input_t1)
self.assertTrue(
torch.allclose(output_t1, torch.tensor([[[44.0]]])),
f"Expected Compute (44.0) due to retention, got {output_t1.item()}",
)
def test_mag_cache_tuple_outputs(self):
"""Test compatibility with models returning (hidden, encoder_hidden) like Flux."""
model = TupleTransformer()
ratios = np.array([1.0, 1.0])
config = MagCacheConfig(threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=ratios)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0: Compute. Input 10.0 -> Output 20.0 (1 block * 2x)
# Residual = 10.0
input_t0 = torch.tensor([[[10.0]]])
enc_t0 = torch.tensor([[[1.0]]])
out_0, _ = model(input_t0, encoder_hidden_states=enc_t0)
self.assertTrue(torch.allclose(out_0, torch.tensor([[[20.0]]])))
# Step 1: Skip. Input 11.0.
# Skipped Output = 11 + 10 = 21.0
input_t1 = torch.tensor([[[11.0]]])
out_1, _ = model(input_t1, encoder_hidden_states=enc_t0)
self.assertTrue(
torch.allclose(out_1, torch.tensor([[[21.0]]])), f"Tuple skip failed. Expected 21.0, got {out_1.item()}"
)
def test_mag_cache_reset(self):
"""Test that state resets correctly after num_inference_steps."""
model = DummyTransformer()
config = MagCacheConfig(
threshold=100.0, num_inference_steps=2, retention_ratio=0.0, mag_ratios=np.array([1.0, 1.0])
)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
input_t = torch.ones(1, 1, 1)
model(input_t) # Step 0
model(input_t) # Step 1 (Skipped)
# Step 2 (Reset -> Step 0) -> Should Compute
# Input 2.0 -> Output 8.0
input_t2 = torch.tensor([[[2.0]]])
output_t2 = model(input_t2)
self.assertTrue(torch.allclose(output_t2, torch.tensor([[[8.0]]])), "State did not reset correctly")
def test_mag_cache_calibration(self):
"""Test that calibration mode records ratios."""
model = DummyTransformer()
config = MagCacheConfig(num_inference_steps=2, calibrate=True)
apply_mag_cache(model, config)
self._set_context(model, "test_context")
# Step 0
# HeadInput = 10. Output = 40. Residual = 30.
# Ratio 0 is placeholder 1.0
model(torch.tensor([[[10.0]]]))
# Check intermediate state
ratios = self._get_calibration_data(model)
self.assertEqual(len(ratios), 1)
self.assertEqual(ratios[0], 1.0)
# Step 1
# HeadInput = 10. Output = 40. Residual = 30.
# PrevResidual = 30. CurrResidual = 30.
# Ratio = 30/30 = 1.0
model(torch.tensor([[[10.0]]]))
# In calibration mode the step still computes fully (no skip): with the same input
# (10.0) the output is 40.0 again, whereas a skip would have produced 41.0.
# After the final step the hook state resets, so the calibration list should be empty.
ratios_after = self._get_calibration_data(model)
self.assertEqual(ratios_after, [])
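Note: these standalone unit tests are superseded by the shared MagCacheTesterMixin added further down. For orientation, a minimal hedged sketch of the user-facing flow they exercised; `model` stands for any diffusers model built on CacheMixin, `sample` for its input, and the values mirror the deleted tests.

from diffusers import MagCacheConfig

config = MagCacheConfig(
    num_inference_steps=2,
    retention_ratio=0.0,    # allow skipping from the first eligible step
    threshold=100.0,        # generous accumulated-error budget, so step 1 is skipped
    mag_ratios=[1.0, 1.0],  # neutral per-step magnitude ratios
)
model.enable_cache(config)  # `model`: placeholder for a CacheMixin-based model

with model.cache_context("magcache"):
    out_0 = model(sample, return_dict=False)[0]  # step 0: computed, residual recorded
    out_1 = model(sample, return_dict=False)[0]  # step 1: blocks skipped, cached residual reused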

View File

@@ -5,8 +5,12 @@ from .cache import (
FasterCacheTesterMixin,
FirstBlockCacheConfigMixin,
FirstBlockCacheTesterMixin,
MagCacheConfigMixin,
MagCacheTesterMixin,
PyramidAttentionBroadcastConfigMixin,
PyramidAttentionBroadcastTesterMixin,
TaylorSeerCacheConfigMixin,
TaylorSeerCacheTesterMixin,
)
from .common import BaseModelTesterConfig, ModelTesterMixin
from .compile import TorchCompileTesterMixin
@@ -50,6 +54,8 @@ __all__ = [
"FasterCacheTesterMixin",
"FirstBlockCacheConfigMixin",
"FirstBlockCacheTesterMixin",
"MagCacheConfigMixin",
"MagCacheTesterMixin",
"GGUFCompileTesterMixin",
"GGUFConfigMixin",
"GGUFTesterMixin",
@@ -65,6 +71,8 @@ __all__ = [
"ModelTesterMixin",
"PyramidAttentionBroadcastConfigMixin",
"PyramidAttentionBroadcastTesterMixin",
"TaylorSeerCacheConfigMixin",
"TaylorSeerCacheTesterMixin",
"QuantizationCompileTesterMixin",
"QuantizationTesterMixin",
"QuantoCompileTesterMixin",

View File

@@ -18,10 +18,18 @@ import gc
import pytest
import torch
from diffusers.hooks import FasterCacheConfig, FirstBlockCacheConfig, PyramidAttentionBroadcastConfig
from diffusers.hooks import (
FasterCacheConfig,
FirstBlockCacheConfig,
MagCacheConfig,
PyramidAttentionBroadcastConfig,
TaylorSeerCacheConfig,
)
from diffusers.hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
from diffusers.hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
from diffusers.hooks.mag_cache import _MAG_CACHE_BLOCK_HOOK, _MAG_CACHE_LEADER_BLOCK_HOOK
from diffusers.hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
from diffusers.hooks.taylorseer_cache import _TAYLORSEER_CACHE_HOOK
from diffusers.models.cache_utils import CacheMixin
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_cache, torch_device
@@ -554,3 +562,192 @@ class FasterCacheTesterMixin(FasterCacheConfigMixin, CacheTesterMixin):
@require_cache_mixin
def test_faster_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class MagCacheConfigMixin:
"""
Base mixin providing MagCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default MagCache config - can be overridden by subclasses.
# Uses neutral ratios [1.0, 1.0] and a high threshold so the second
# inference step is always skipped, which is required by _test_cache_inference.
MAG_CACHE_CONFIG = {
"num_inference_steps": 2,
"retention_ratio": 0.0,
"threshold": 100.0,
"mag_ratios": [1.0, 1.0],
}
def _get_cache_config(self):
return MagCacheConfig(**self.MAG_CACHE_CONFIG)
def _get_hook_names(self):
return [_MAG_CACHE_LEADER_BLOCK_HOOK, _MAG_CACHE_BLOCK_HOOK]
@is_cache
class MagCacheTesterMixin(MagCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing MagCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@require_cache_mixin
def test_mag_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_mag_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_mag_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_mag_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_mag_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_mag_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()
@is_cache
class TaylorSeerCacheConfigMixin:
"""
Base mixin providing TaylorSeerCache config.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
"""
# Default TaylorSeerCache config - can be overridden by subclasses.
# Uses a low cache_interval and disable_cache_before_step=1 so the second
# inference step is always predicted, which is required by _test_cache_inference.
TAYLORSEER_CACHE_CONFIG = {
"cache_interval": 3,
"disable_cache_before_step": 1,
"max_order": 1,
}
def _get_cache_config(self):
return TaylorSeerCacheConfig(**self.TAYLORSEER_CACHE_CONFIG)
def _get_hook_names(self):
return [_TAYLORSEER_CACHE_HOOK]
@is_cache
class TaylorSeerCacheTesterMixin(TaylorSeerCacheConfigMixin, CacheTesterMixin):
"""
Mixin class for testing TaylorSeerCache on models.
Expected class attributes:
- model_class: The model class to test (must use CacheMixin)
Expected methods to be implemented by subclasses:
- get_init_dict(): Returns dict of arguments to initialize the model
- get_dummy_inputs(): Returns dict of inputs to pass to the model forward pass
Pytest mark: cache
Use `pytest -m "not cache"` to skip these tests
"""
@torch.no_grad()
def _test_cache_inference(self):
"""Test that model can run inference with TaylorSeer cache enabled (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
# TaylorSeer requires cache_context to be set for inference
with model.cache_context("taylorseer_test"):
# First pass populates the cache
_ = model(**inputs_dict, return_dict=False)[0]
# Create modified inputs for second pass
inputs_dict_step2 = inputs_dict.copy()
if self.cache_input_key in inputs_dict_step2:
inputs_dict_step2[self.cache_input_key] = inputs_dict_step2[self.cache_input_key] + torch.randn_like(
inputs_dict_step2[self.cache_input_key]
)
# Second pass - TaylorSeer should use cached Taylor series predictions
output_with_cache = model(**inputs_dict_step2, return_dict=False)[0]
assert output_with_cache is not None, "Model output should not be None with cache enabled."
assert not torch.isnan(output_with_cache).any(), "Model output contains NaN with cache enabled."
# Run same inputs without cache to compare
model.disable_cache()
output_without_cache = model(**inputs_dict_step2, return_dict=False)[0]
# Cached output should be different from non-cached output (due to approximation)
assert not torch.allclose(output_without_cache, output_with_cache, atol=1e-5), (
"Cached output should be different from non-cached output due to cache approximation."
)
@torch.no_grad()
def _test_reset_stateful_cache(self):
"""Test that _reset_stateful_cache resets the TaylorSeer cache state (requires cache_context)."""
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
model.eval()
config = self._get_cache_config()
model.enable_cache(config)
with model.cache_context("taylorseer_test"):
_ = model(**inputs_dict, return_dict=False)[0]
model._reset_stateful_cache()
model.disable_cache()
@require_cache_mixin
def test_taylorseer_cache_enable_disable_state(self):
self._test_cache_enable_disable_state()
@require_cache_mixin
def test_taylorseer_cache_double_enable_raises_error(self):
self._test_cache_double_enable_raises_error()
@require_cache_mixin
def test_taylorseer_cache_hooks_registered(self):
self._test_cache_hooks_registered()
@require_cache_mixin
def test_taylorseer_cache_inference(self):
self._test_cache_inference()
@require_cache_mixin
def test_taylorseer_cache_context_manager(self):
self._test_cache_context_manager()
@require_cache_mixin
def test_taylorseer_cache_reset_stateful_cache(self):
self._test_reset_stateful_cache()
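Note: as _test_cache_inference above shows, TaylorSeer requires an explicit cache_context. In user code the flow looks roughly like the following sketch; `model` and `sample` are placeholders.

from diffusers.hooks import TaylorSeerCacheConfig

config = TaylorSeerCacheConfig(cache_interval=3, disable_cache_before_step=1, max_order=1)
model.enable_cache(config)

with model.cache_context("taylorseer"):
    out_0 = model(sample, return_dict=False)[0]  # full compute, populates the Taylor series state
    out_1 = model(sample, return_dict=False)[0]  # predicted from the cached series

model._reset_stateful_cache()  # clear per-run state, as in _test_reset_stateful_cache
model.disable_cache()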

View File

@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -37,6 +37,7 @@ from ..testing_utils import (
IPAdapterTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MagCacheTesterMixin,
MemoryTesterMixin,
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
@@ -45,6 +46,7 @@ from ..testing_utils import (
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
TaylorSeerCacheTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
@@ -430,3 +432,11 @@ class TestFluxTransformerFasterCache(FluxTransformerTesterConfig, FasterCacheTes
"tensor_format": "BCHW",
"is_guidance_distilled": True,
}
class TestFluxTransformerMagCache(FluxTransformerTesterConfig, MagCacheTesterMixin):
"""MagCache tests for Flux Transformer."""
class TestFluxTransformerTaylorSeerCache(FluxTransformerTesterConfig, TaylorSeerCacheTesterMixin):
"""TaylorSeerCache tests for Flux Transformer."""