Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-20 19:34:48 +08:00)

Compare commits: qwenimage-...cp-fix (17 commits)
| SHA1 |
|---|
| e41ca61479 |
| 197dd5f312 |
| 3dcc9ca73a |
| d65f857d63 |
| 3b12a0b77d |
| 450564563e |
| 56114f46cc |
| fb15ff526f |
| 5bfc7dd419 |
| f92578342f |
| 8018a6a733 |
| 0845ca07d3 |
| 881e262c08 |
| a66787b62b |
| 1d76322675 |
| 428399b590 |
| faf61a4877 |
```diff
@@ -44,11 +44,16 @@ class ContextParallelConfig:
     Args:
         ring_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ring Attention. The sequence is split across devices and each device
+            computes attention between its local Q and the KV chunks passed sequentially around the ring. Lower
+            memory (only holds 1/N of KV at a time), overlaps compute with communication, but requires N iterations
+            to see all tokens. Best for long sequences with limited memory/bandwidth. Must be a divisor of the total
+            number of devices in the context parallel mesh.
         ulysses_degree (`int`, *optional*, defaults to `1`):
-            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
-            total number of devices in the context parallel mesh.
+            Number of devices to use for Ulysses Attention. The sequence is split across devices and each device
+            computes local QKV, then all-gathers the KV chunks to compute full attention in one pass. Higher memory
+            (stores all KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate
+            sequences with good interconnect bandwidth.
         convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
```
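The reworked docstrings spell out the trade-off between the two degrees: ring attention minimizes memory at the cost of N sequential passes, while Ulysses attention does a single all-to-all pass at the cost of holding all KV. A minimal sketch of picking one over the other, assuming the top-level `diffusers.ContextParallelConfig` import and an 8-GPU job (the helper function is hypothetical, not part of the library):

```python
# Hypothetical helper, not part of the diffusers API: picks a ContextParallelConfig
# based on the trade-off described in the docstrings above.
from diffusers import ContextParallelConfig  # import path assumed


def pick_cp_config(num_devices: int, long_sequence: bool) -> ContextParallelConfig:
    if long_sequence:
        # Ring Attention: each rank holds 1/N of the KV, so memory scales down with the
        # ring size, at the cost of N sequential exchange steps around the ring.
        return ContextParallelConfig(ring_degree=num_devices)
    # Ulysses Attention: one all-to-all pass, lower latency, but the full KV is gathered.
    return ContextParallelConfig(ulysses_degree=num_devices)


config = pick_cp_config(num_devices=8, long_sequence=True)
```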
```diff
@@ -79,28 +84,45 @@ class ContextParallelConfig:
         if self.ulysses_degree is None:
             self.ulysses_degree = 1

+        if self.ring_degree == 1 and self.ulysses_degree == 1:
+            raise ValueError(
+                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
+            )
+        if self.ring_degree < 1 or self.ulysses_degree < 1:
+            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
+        if self.ring_degree > 1 and self.ulysses_degree > 1:
+            raise ValueError(
+                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
+            )
+        if self.rotate_method != "allgather":
+            raise NotImplementedError(
+                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
+            )
+
+    @property
+    def mesh_shape(self) -> Tuple[int, int]:
+        return (self.ring_degree, self.ulysses_degree)
+
+    @property
+    def mesh_dim_names(self) -> Tuple[str, str]:
+        """Dimension names for the device mesh."""
+        return ("ring", "ulysses")
+
     def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
         self._rank = rank
         self._world_size = world_size
         self._device = device
         self._mesh = mesh
-        if self.ring_degree is None:
-            self.ring_degree = 1
-        if self.ulysses_degree is None:
-            self.ulysses_degree = 1
-        if self.rotate_method != "allgather":
-            raise NotImplementedError(
-                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
-            )
-        if self._flattened_mesh is None:
-            self._flattened_mesh = self._mesh._flatten()
-        if self._ring_mesh is None:
-            self._ring_mesh = self._mesh["ring"]
-        if self._ulysses_mesh is None:
-            self._ulysses_mesh = self._mesh["ulysses"]
-        if self._ring_local_rank is None:
-            self._ring_local_rank = self._ring_mesh.get_local_rank()
-        if self._ulysses_local_rank is None:
-            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
+        if self.ulysses_degree * self.ring_degree > world_size:
+            raise ValueError(
+                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
+            )
+        self._flattened_mesh = self._mesh._flatten()
+        self._ring_mesh = self._mesh["ring"]
+        self._ulysses_mesh = self._mesh["ulysses"]
+        self._ring_local_rank = self._ring_mesh.get_local_rank()
+        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
```
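The `__post_init__` additions reject degenerate or unsupported combinations up front, and the new `mesh_shape` / `mesh_dim_names` properties expose values that were previously hard-coded at the call site. A small sketch of the resulting behavior, assuming the same top-level import as above:

```python
# Illustrative only: exercises the validation rules added in __post_init__ above.
from diffusers import ContextParallelConfig  # import path assumed

cfg = ContextParallelConfig(ring_degree=4)  # ok: ulysses_degree defaults to 1
print(cfg.mesh_shape)       # (4, 1)
print(cfg.mesh_dim_names)   # ("ring", "ulysses")

try:
    ContextParallelConfig()  # ring_degree == ulysses_degree == 1
except ValueError as e:
    print(e)  # context parallelism requires a degree greater than 1

try:
    ContextParallelConfig(ring_degree=2, ulysses_degree=2)
except ValueError as e:
    print(e)  # unified Ulysses-Ring attention is not yet supported
```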
```diff
@@ -119,7 +141,7 @@ class ParallelConfig:
     _rank: int = None
     _world_size: int = None
     _device: torch.device = None
-    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
+    _mesh: torch.distributed.device_mesh.DeviceMesh = None

     def setup(
         self,
@@ -127,14 +149,14 @@ class ParallelConfig:
         world_size: int,
         device: torch.device,
         *,
-        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
+        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
     ):
         self._rank = rank
         self._world_size = world_size
         self._device = device
-        self._cp_mesh = cp_mesh
+        self._mesh = mesh
         if self.context_parallel_config is not None:
-            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
+            self.context_parallel_config.setup(rank, world_size, device, mesh)


 @dataclass(frozen=True)
```
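`ParallelConfig` now stores a generic `_mesh` and `setup` takes `mesh=` instead of `cp_mesh=`; the mesh itself can be built from the config's own `mesh_shape` and `mesh_dim_names`. A rough sketch of that wiring, assuming `torch.distributed` was already initialized (for example via `torchrun`) and that both config classes are importable from the top-level `diffusers` namespace:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from diffusers import ContextParallelConfig, ParallelConfig  # import paths assumed

# Assumes the process group was initialized before this point (e.g. by torchrun + init_process_group).
rank, world_size = dist.get_rank(), dist.get_world_size()
cp_config = ContextParallelConfig(ring_degree=world_size)
config = ParallelConfig(context_parallel_config=cp_config)

mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=cp_config.mesh_shape,          # (ring_degree, ulysses_degree)
    mesh_dim_names=cp_config.mesh_dim_names,  # ("ring", "ulysses")
)
config.setup(rank, world_size, torch.device("cuda", rank), mesh=mesh)  # keyword renamed from cp_mesh
```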
```diff
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
     _backends = {}
     _constraints = {}
     _supported_arg_names = {}
-    _supports_context_parallel = {}
+    _supports_context_parallel = set()
     _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
     _checks_enabled = DIFFUSERS_ATTN_CHECKS

@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
             cls._backends[backend] = func
             cls._constraints[backend] = constraints or []
             cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
-            cls._supports_context_parallel[backend] = supports_context_parallel
+            if supports_context_parallel:
+                cls._supports_context_parallel.add(backend.value)
+
             return func

         return decorator
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
         return list(cls._backends.keys())

     @classmethod
-    def _is_context_parallel_enabled(
-        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
+    def _is_context_parallel_available(
+        cls,
+        backend: AttentionBackendName,
     ) -> bool:
-        supports_context_parallel = backend in cls._supports_context_parallel
-        is_degree_greater_than_1 = parallel_config is not None and (
-            parallel_config.context_parallel_config.ring_degree > 1
-            or parallel_config.context_parallel_config.ulysses_degree > 1
-        )
-        return supports_context_parallel and is_degree_greater_than_1
+        supports_context_parallel = backend.value in cls._supports_context_parallel
+        return supports_context_parallel

     @contextlib.contextmanager
```
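On the registry side, context-parallel capability is now tracked as a set of backend-name strings instead of a per-backend boolean dict, and `_is_context_parallel_available` answers a pure capability question (the degree checks moved into the config and into `enable_parallelism`). A toy re-implementation of just that pattern, independent of the private diffusers registry:

```python
# Toy re-implementation of the registry pattern above (not the diffusers class itself):
# capability is recorded as a set of backend-name strings rather than a {backend: bool} dict.
from enum import Enum


class BackendName(str, Enum):
    FLASH = "flash"
    NATIVE = "native"


class Registry:
    _backends: dict = {}
    _supports_context_parallel: set = set()

    @classmethod
    def register(cls, backend: BackendName, supports_context_parallel: bool = False):
        def decorator(func):
            cls._backends[backend] = func
            if supports_context_parallel:
                cls._supports_context_parallel.add(backend.value)
            return func
        return decorator

    @classmethod
    def _is_context_parallel_available(cls, backend: BackendName) -> bool:
        return backend.value in cls._supports_context_parallel


@Registry.register(BackendName.FLASH, supports_context_parallel=True)
def flash_attention(*args, **kwargs): ...


assert Registry._is_context_parallel_available(BackendName.FLASH)
assert not Registry._is_context_parallel_available(BackendName.NATIVE)
```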
```diff
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
     backend_name = AttentionBackendName(backend)
     backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

-    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
-        backend_name, parallel_config
-    ):
-        raise ValueError(
-            f"Backend {backend_name} either does not support context parallelism or context parallelism "
-            f"was enabled with a world size of 1."
-        )
-
     kwargs = {
         "query": query,
         "key": key,
```
```diff
@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         config: Union[ParallelConfig, ContextParallelConfig],
         cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
     ):
-        from ..hooks.context_parallel import apply_context_parallel
-        from .attention import AttentionModuleMixin
-        from .attention_processor import Attention, MochiAttention
-
         logger.warning(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )

+        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+            raise RuntimeError(
+                "torch.distributed must be available and initialized before calling `enable_parallelism`."
+            )
+
+        from ..hooks.context_parallel import apply_context_parallel
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
+        from .attention_processor import Attention, MochiAttention
+
         if isinstance(config, ContextParallelConfig):
             config = ParallelConfig(context_parallel_config=config)

-        if not torch.distributed.is_initialized():
-            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")
-
         rank = torch.distributed.get_rank()
         world_size = torch.distributed.get_world_size()
         device_type = torch._C._get_accelerator().type
         device_module = torch.get_device_module(device_type)
         device = torch.device(device_type, rank % device_module.device_count())

-        cp_mesh = None
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+
+        if config.context_parallel_config is not None:
+            for module in self.modules():
+                if not isinstance(module, attention_classes):
+                    continue
+
+                processor = module.processor
+                if processor is None or not hasattr(processor, "_attention_backend"):
+                    continue
+
+                attention_backend = processor._attention_backend
+                if attention_backend is None:
+                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
+                else:
+                    attention_backend = AttentionBackendName(attention_backend)
+
+                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
+                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
+                    raise ValueError(
+                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
+                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
+                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
+                        f"calling `enable_parallelism()`."
+                    )
+
+                # All modules use the same attention processor and backend. We don't need to
+                # iterate over all modules after checking the first processor
+                break
+
+        mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
-                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
-            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
-                raise ValueError(
-                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
-                )
-            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
-                raise ValueError(
-                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
-                )
-            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
+            mesh = torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
-                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
-                mesh_dim_names=("ring", "ulysses"),
+                mesh_shape=cp_config.mesh_shape,
+                mesh_dim_names=cp_config.mesh_dim_names,
             )

-        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
-
-        if cp_plan is None and self._cp_plan is None:
-            raise ValueError(
-                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
-            )
-        cp_plan = cp_plan if cp_plan is not None else self._cp_plan
-
-        if config.context_parallel_config is not None:
-            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+        config.setup(rank, world_size, device, mesh=mesh)

         self._parallel_config = config

-        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
         for module in self.modules():
             if not isinstance(module, attention_classes):
                 continue
```
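With the checks relocated, an incompatible attention backend now fails inside `enable_parallelism` before any context-parallel hooks are applied, with an error that points at `set_attention_backend`. A hedged sketch of the happy path at the model level (the class name, checkpoint id, and the `"flash"` backend string are illustrative, not taken from this diff):

```python
# Illustrative sequence; assumes torch.distributed is already initialized
# (see the end-to-end sketch further below).
import torch
from diffusers import ContextParallelConfig, QwenImageTransformer2DModel  # names assumed

model = QwenImageTransformer2DModel.from_pretrained(
    "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Choose a context-parallel-capable backend first; otherwise enable_parallelism raises the
# ValueError added above, listing the compatible backends.
model.set_attention_backend("flash")
model.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
```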
```diff
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 continue
             processor._parallel_config = config

+        if config.context_parallel_config is not None:
+            if cp_plan is None and self._cp_plan is None:
+                raise ValueError(
+                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+                )
+            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
+            apply_context_parallel(self, config.context_parallel_config, cp_plan)
+
     @classmethod
     def _load_pretrained_model(
         cls,
```
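Putting the pieces together, the reworked method expects an initialized process group, a context-parallel-capable backend, and either an explicit `cp_plan` or a model that defines `_cp_plan`. A minimal end-to-end sketch under those assumptions, launched with something like `torchrun --nproc-per-node=2 example.py` (pipeline, checkpoint, and backend names are placeholders):

```python
# example.py: illustrative context-parallel inference flow, not a verified recipe from this PR.
import torch
import torch.distributed as dist
from diffusers import ContextParallelConfig, DiffusionPipeline  # imports assumed

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

pipe = DiffusionPipeline.from_pretrained("some/checkpoint", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("flash")  # must be a context-parallel-capable backend
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=dist.get_world_size()))
# cp_plan is omitted here, so the transformer must ship a _cp_plan; otherwise pass cp_plan=... explicitly.

image = pipe("a photo of a cat", num_inference_steps=30).images[0]
if rank == 0:
    image.save("output.png")
dist.destroy_process_group()
```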