Compare commits

..

3 Commits

Author SHA1 Message Date
DN6
eff791831f update 2026-03-13 10:28:38 +05:30
Dhruv Nair
07c5ba8eee [Context Parallel] Add support for custom device mesh (#13064)
* add custom mesh support

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-11 16:42:11 +05:30
Dhruv Nair
897aed72fa [Quantization] Deprecate Quanto (#13180)
* update

* update
2026-03-11 09:26:46 +05:30
7 changed files with 183 additions and 37 deletions

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -68,6 +78,7 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh._flatten()
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -14,6 +14,7 @@
import importlib
import inspect
import os
import shutil
import sys
import traceback
import warnings
@@ -1883,6 +1884,36 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
"""Save custom code files (blocks config and Python modules) to the save directory."""
if self._blocks is None:
return
blocks_module = type(self._blocks).__module__
is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
if not is_custom_code:
return
os.makedirs(save_directory, exist_ok=True)
self._blocks.save_pretrained(save_directory)
source_file = inspect.getfile(type(self._blocks))
module_file = os.path.basename(source_file)
dest_file = os.path.join(save_directory, module_file)
if os.path.abspath(source_file) != os.path.abspath(dest_file):
shutil.copyfile(source_file, dest_file)
from ..utils.dynamic_modules_utils import get_relative_import_files
for rel_file in get_relative_import_files(source_file):
rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
rel_dest = os.path.join(save_directory, rel_name)
if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
shutil.copyfile(rel_file, rel_dest)
def save_pretrained(
self,
save_directory: str | os.PathLike,
@@ -1998,6 +2029,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self._maybe_save_custom_code(save_directory)
self.save_config(save_directory=save_directory)
if push_to_hub:

View File

@@ -36,7 +36,7 @@ from typing import Any, Callable
from packaging import version
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
if is_torch_available():
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
modules_to_not_convert: list[str] | None = None,
**kwargs,
):
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
deprecate("QuantoConfig", "1.0.0", deprecation_message)
self.quant_method = QuantizationMethod.QUANTO
self.weights_dtype = weights_dtype
self.modules_to_not_convert = modules_to_not_convert

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
from diffusers.utils.import_utils import is_optimum_quanto_version
from ...utils import (
deprecate,
get_module_from_name,
is_accelerate_available,
is_accelerate_version,
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
if not is_optimum_quanto_available():
raise ImportError(
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"

View File

@@ -13,8 +13,8 @@
# limitations under the License.
import gc
import unittest
import pytest
import torch
from diffusers.hooks import HookRegistry, ModelHook
@@ -134,18 +134,20 @@ class SkipLayerHook(ModelHook):
return output
class TestHooks:
class HookTests(unittest.TestCase):
in_features = 4
hidden_features = 8
out_features = 4
num_layers = 2
def setup_method(self):
def setUp(self):
params = self.get_module_parameters()
self.model = DummyModel(**params)
self.model.to(torch_device)
def teardown_method(self):
def tearDown(self):
super().tearDown()
del self.model
gc.collect()
free_memory()
@@ -169,20 +171,20 @@ class TestHooks:
registry_repr = repr(registry)
expected_repr = "HookRegistry(\n (0) add_hook - AddHook\n (1) multiply_hook - MultiplyHook(value=2)\n)"
assert len(registry.hooks) == 2
assert registry._hook_order == ["add_hook", "multiply_hook"]
assert registry_repr == expected_repr
self.assertEqual(len(registry.hooks), 2)
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
self.assertEqual(registry_repr, expected_repr)
registry.remove_hook("add_hook")
assert len(registry.hooks) == 1
assert registry._hook_order == ["multiply_hook"]
self.assertEqual(len(registry.hooks), 1)
self.assertEqual(registry._hook_order, ["multiply_hook"])
def test_stateful_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
registry.register_hook(StatefulAddHook(1), "stateful_add_hook")
assert registry.hooks["stateful_add_hook"].increment == 0
self.assertEqual(registry.hooks["stateful_add_hook"].increment, 0)
input = torch.randn(1, 4, device=torch_device, generator=self.get_generator())
num_repeats = 3
@@ -192,13 +194,13 @@ class TestHooks:
if i == 0:
output1 = result
assert registry.get_hook("stateful_add_hook").increment == num_repeats
self.assertEqual(registry.get_hook("stateful_add_hook").increment, num_repeats)
registry.reset_stateful_hooks()
output2 = self.model(input)
assert registry.get_hook("stateful_add_hook").increment == 1
assert torch.allclose(output1, output2)
self.assertEqual(registry.get_hook("stateful_add_hook").increment, 1)
self.assertTrue(torch.allclose(output1, output2))
def test_inference(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -216,9 +218,9 @@ class TestHooks:
new_input = input * 2 + 1
output3 = self.model(new_input).mean().detach().cpu().item()
assert output1 == pytest.approx(output2, abs=5e-6)
assert output1 == pytest.approx(output3, abs=5e-6)
assert output2 == pytest.approx(output3, abs=5e-6)
self.assertAlmostEqual(output1, output2, places=5)
self.assertAlmostEqual(output1, output3, places=5)
self.assertAlmostEqual(output2, output3, places=5)
def test_skip_layer_hook(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -226,29 +228,30 @@ class TestHooks:
input = torch.zeros(1, 4, device=torch_device)
output = self.model(input).mean().detach().cpu().item()
assert output == 0.0
self.assertEqual(output, 0.0)
registry.remove_hook("skip_layer_hook")
registry.register_hook(SkipLayerHook(skip_layer=False), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
def test_skip_layer_internal_block(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model.linear_1)
input = torch.zeros(1, 4, device=torch_device)
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
with pytest.raises(RuntimeError, match="mat1 and mat2 shapes cannot be multiplied"):
with self.assertRaises(RuntimeError) as cm:
self.model(input).mean().detach().cpu().item()
self.assertIn("mat1 and mat2 shapes cannot be multiplied", str(cm.exception))
registry.remove_hook("skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
registry = HookRegistry.check_if_exists_or_initialize(self.model.blocks[1])
registry.register_hook(SkipLayerHook(skip_layer=True), "skip_layer_hook")
output = self.model(input).mean().detach().cpu().item()
assert output != 0.0
self.assertNotEqual(output, 0.0)
def test_invocation_order_stateful_first(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -275,7 +278,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -286,7 +289,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_middle(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -313,7 +316,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -324,7 +327,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook_2")
with CaptureLogger(logger) as cap_logger:
@@ -333,7 +336,7 @@ class TestHooks:
expected_invocation_order_log = (
("MultiplyHook pre_forward\nMultiplyHook post_forward\n").replace(" ", "").replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
def test_invocation_order_stateful_last(self):
registry = HookRegistry.check_if_exists_or_initialize(self.model)
@@ -360,7 +363,7 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)
registry.remove_hook("add_hook")
with CaptureLogger(logger) as cap_logger:
@@ -371,4 +374,4 @@ class TestHooks:
.replace(" ", "")
.replace("\n", "")
)
assert output == expected_invocation_order_log
self.assertEqual(output, expected_invocation_order_log)

View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)