Compare commits


17 Commits

Author       SHA1        Message                          Date
Dhruv Nair   e41ca61479  Merge branch 'main' into cp-fix  2025-11-10 15:21:33 +05:30
Dhruv Nair   197dd5f312  Merge branch 'main' into cp-fix  2025-11-06 17:46:31 +05:30
Sayak Paul   3dcc9ca73a  Merge branch 'main' into cp-fix  2025-11-04 07:20:44 +05:30
Sayak Paul   d65f857d63  Merge branch 'main' into cp-fix  2025-10-31 13:15:51 +05:30
DN6          3b12a0b77d  update                           2025-10-30 22:31:24 +05:30
DN6          450564563e  update                           2025-10-30 22:31:24 +05:30
Dhruv Nair   56114f46cc  Merge branch 'main' into cp-fix  2025-10-30 08:08:08 +05:30
DN6          fb15ff526f  update                           2025-10-08 14:35:36 +05:30
DN6          5bfc7dd419  update                           2025-10-07 18:40:24 +05:30
DN6          f92578342f  update                           2025-10-07 17:47:02 +05:30
DN6          8018a6a733  update                           2025-10-07 17:45:42 +05:30
DN6          0845ca07d3  update                           2025-10-07 17:37:50 +05:30
DN6          881e262c08  update                           2025-10-07 17:35:04 +05:30
DN6          a66787b62b  update                           2025-10-07 17:00:10 +05:30
DN6          1d76322675  update                           2025-10-07 16:54:14 +05:30
DN6          428399b590  update                           2025-10-07 15:54:04 +05:30
DN6          faf61a4877  update                           2025-10-07 14:42:35 +05:30
3 changed files with 109 additions and 76 deletions

View File

@@ -44,11 +44,16 @@ class ContextParallelConfig:
    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
+            attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
+            of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
+            for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
+            context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
        ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. Sequence split is across devices. Each device computes
+            local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
+            KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
+            good interconnect bandwidth.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
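
The rewritten docstring spells out the trade-off between the two schemes. A minimal sketch of choosing between them (the `diffusers` import path and the degree values are assumptions, not part of this diff):

```python
# Hedged sketch: picking a scheme per the trade-offs described in the docstring above.
from diffusers import ContextParallelConfig  # import path assumed

# Long sequences, limited memory/bandwidth: Ring Attention holds only 1/N of the KV at a time.
ring_config = ContextParallelConfig(ring_degree=4)

# Moderate sequences, fast interconnect: Ulysses Attention all-gathers KV for a single pass.
ulysses_config = ContextParallelConfig(ulysses_degree=4)
```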
@@ -79,29 +84,46 @@ class ContextParallelConfig:
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.ring_degree == 1 and self.ulysses_degree == 1:
raise ValueError(
"Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
)
@property
def mesh_shape(self) -> Tuple[int, int]:
return (self.ring_degree, self.ulysses_degree)
@property
def mesh_dim_names(self) -> Tuple[str, str]:
"""Dimension names for the device mesh."""
return ("ring", "ulysses")
def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
self._rank = rank
self._world_size = world_size
self._device = device
self._mesh = mesh
if self.ring_degree is None:
self.ring_degree = 1
if self.ulysses_degree is None:
self.ulysses_degree = 1
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
if self.ulysses_degree * self.ring_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
if self._flattened_mesh is None:
self._flattened_mesh = self._mesh._flatten()
if self._ring_mesh is None:
self._ring_mesh = self._mesh["ring"]
if self._ulysses_mesh is None:
self._ulysses_mesh = self._mesh["ulysses"]
if self._ring_local_rank is None:
self._ring_local_rank = self._ring_mesh.get_local_rank()
if self._ulysses_local_rank is None:
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()
self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
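
With the checks consolidated in `__post_init__`, invalid combinations now fail when the config is constructed rather than when parallelism is enabled. A sketch of the behavior implied by the hunk above (illustrative values; assuming the classes from this branch):

```python
from diffusers import ContextParallelConfig  # import path assumed

ContextParallelConfig(ring_degree=2)     # valid: ulysses_degree defaults to 1
ContextParallelConfig(ulysses_degree=2)  # valid: ring_degree defaults to 1

# Each of the following should now raise during __post_init__:
# ContextParallelConfig()                                         # both degrees are 1
# ContextParallelConfig(ring_degree=0)                            # degrees must be >= 1
# ContextParallelConfig(ring_degree=2, ulysses_degree=2)          # unified Ulysses-Ring not supported yet
# ContextParallelConfig(ring_degree=2, rotate_method="alltoall")  # only "allgather" is implemented
```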
@dataclass
@@ -119,7 +141,7 @@ class ParallelConfig:
    _rank: int = None
    _world_size: int = None
    _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None

    def setup(
        self,
@@ -127,14 +149,14 @@ class ParallelConfig:
        world_size: int,
        device: torch.device,
        *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh

        if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
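
The new `mesh_shape` and `mesh_dim_names` properties centralize the values that `enable_parallelism` previously passed to `init_device_mesh` by hand, and `setup()` rejects configurations larger than the world size. A small sketch (assuming the classes from this branch):

```python
from diffusers import ContextParallelConfig  # import path assumed

config = ContextParallelConfig(ring_degree=2)

# These two properties are what init_device_mesh(...) receives further down.
assert config.mesh_shape == (2, 1)                   # (ring_degree, ulysses_degree)
assert config.mesh_dim_names == ("ring", "ulysses")

# setup() additionally enforces ring_degree * ulysses_degree <= world_size,
# so this config needs at least two participating processes.
```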

View File

@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend.value)
            return func

        return decorator
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
        return list(cls._backends.keys())

    @classmethod
-    def _is_context_parallel_enabled(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+    def _is_context_parallel_available(
+        cls,
+        backend: AttentionBackendName,
    ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        supports_context_parallel = backend.value in cls._supports_context_parallel
+        return supports_context_parallel
@contextlib.contextmanager
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
    backend_name = AttentionBackendName(backend)
    backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
    kwargs = {
        "query": query,
        "key": key,

View File

@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
config: Union[ParallelConfig, ContextParallelConfig],
cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
):
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
logger.warning(
"`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
)
if not torch.distributed.is_available() and not torch.distributed.is_initialized():
raise RuntimeError(
"torch.distributed must be available and initialized before calling `enable_parallelism`."
)
from ..hooks.context_parallel import apply_context_parallel
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from .attention_processor import Attention, MochiAttention
if isinstance(config, ContextParallelConfig):
config = ParallelConfig(context_parallel_config=config)
if not torch.distributed.is_initialized():
raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
device_type = torch._C._get_accelerator().type
device_module = torch.get_device_module(device_type)
device = torch.device(device_type, rank % device_module.device_count())
cp_mesh = None
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
if config.context_parallel_config is not None:
for module in self.modules():
if not isinstance(module, attention_classes):
continue
processor = module.processor
if processor is None or not hasattr(processor, "_attention_backend"):
continue
attention_backend = processor._attention_backend
if attention_backend is None:
attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
else:
attention_backend = AttentionBackendName(attention_backend)
if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
raise ValueError(
f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
f"is using backend '{attention_backend.value}' which does not support context parallelism. "
f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
f"calling `enable_parallelism()`."
)
# All modules use the same attention processor and backend. We don't need to
# iterate over all modules after checking the first processor
break
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
raise ValueError(
f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
)
cp_mesh = torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
mesh_dim_names=("ring", "ulysses"),
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,
)
config.setup(rank, world_size, device, cp_mesh=cp_mesh)
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
if config.context_parallel_config is not None:
apply_context_parallel(self, config.context_parallel_config, cp_plan)
config.setup(rank, world_size, device, mesh=mesh)
self._parallel_config = config
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():
if not isinstance(module, attention_classes):
continue
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
continue
processor._parallel_config = config
if config.context_parallel_config is not None:
if cp_plan is None and self._cp_plan is None:
raise ValueError(
"`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
)
cp_plan = cp_plan if cp_plan is not None else self._cp_plan
apply_context_parallel(self, config.context_parallel_config, cp_plan)
@classmethod
def _load_pretrained_model(
cls,
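
Taken together, `enable_parallelism` now verifies the attention backend before any device mesh is created, so an incompatible backend fails fast with the list of supported ones. A hedged end-to-end sketch of the intended call order (the launch command, the `load_transformer` helper, and the backend name are assumptions, not taken from this diff):

```python
# Run with something like: torchrun --nproc_per_node=2 run_cp.py  (launcher assumed)
import torch.distributed as dist

from diffusers import ContextParallelConfig  # import path assumed

dist.init_process_group("nccl")  # must be initialized before enable_parallelism()

model = load_transformer()  # hypothetical helper returning a diffusers transformer with a _cp_plan

# Pick a CP-capable backend first; otherwise enable_parallelism() now raises and
# lists the compatible backends instead of failing later inside attention dispatch.
model.set_attention_backend("flash")  # backend name assumed to support context parallelism

# Ring Attention across two ranks; ulysses_degree defaults to 1.
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))

dist.destroy_process_group()
```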