Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 05:24:20 +08:00)
Compare commits (17 commits)

| SHA1 |
|---|
| e41ca61479 |
| 197dd5f312 |
| 3dcc9ca73a |
| d65f857d63 |
| 3b12a0b77d |
| 450564563e |
| 56114f46cc |
| fb15ff526f |
| 5bfc7dd419 |
| f92578342f |
| 8018a6a733 |
| 0845ca07d3 |
| 881e262c08 |
| a66787b62b |
| 1d76322675 |
| 428399b590 |
| faf61a4877 |
@@ -44,11 +44,16 @@ class ContextParallelConfig:

    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ring attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh.
            Number of devices to use for Ring Attention. Sequence is split across devices. Each device computes
            attention between its local Q and KV chunks passed sequentially around ring. Lower memory (only holds 1/N
            of KV at a time), overlaps compute with communication, but requires N iterations to see all tokens. Best
            for long sequences with limited memory/bandwidth. Number of devices to use for ring attention within a
            context parallel region. Must be a divisor of the total number of devices in the context parallel mesh.
        ulysses_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for ulysses attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh.
            Number of devices to use for Ulysses Attention. Sequence is split across devices. Each device computes
            local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory (stores all
            KV), requires high-bandwidth all-to-all communication, but lower latency. Best for moderate sequences with
            good interconnect bandwidth.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
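The two descriptions above boil down to different communication patterns over the same math. The single-process sketch below (an illustration, not the diffusers implementation) mimics both schemes for one attention head and checks that each reproduces full attention: the ring path visits KV chunks one at a time and merges partial results with a running log-sum-exp, while the Ulysses path gathers all KV chunks and attends in a single pass, as described above.

```python
# Single-process sketch (an illustration, not the diffusers implementation).
import torch


def full_attention(q, k, v):
    # Reference: standard softmax attention over the whole sequence.
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v


def ring_attention_sim(q_chunks, k_chunks, v_chunks):
    # Each "device" keeps its Q chunk and sees one KV chunk at a time,
    # merging partial outputs with a running log-sum-exp, so only 1/N of
    # the KV has to be resident per step.
    outputs = []
    for q in q_chunks:
        acc, lse = None, None
        for k, v in zip(k_chunks, v_chunks):
            scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
            block_lse = torch.logsumexp(scores, dim=-1, keepdim=True)
            block_out = torch.softmax(scores, dim=-1) @ v
            if acc is None:
                acc, lse = block_out, block_lse
            else:
                new_lse = torch.logaddexp(lse, block_lse)
                acc = acc * (lse - new_lse).exp() + block_out * (block_lse - new_lse).exp()
                lse = new_lse
        outputs.append(acc)
    return torch.cat(outputs, dim=0)


def ulysses_attention_sim(q_chunks, k_chunks, v_chunks):
    # Each "device" gathers the full KV and attends in a single pass:
    # more resident memory, but no iterative accumulation.
    k_full, v_full = torch.cat(k_chunks, dim=0), torch.cat(v_chunks, dim=0)
    return torch.cat([full_attention(q, k_full, v_full) for q in q_chunks], dim=0)


torch.manual_seed(0)
q, k, v = torch.randn(3, 8, 16).unbind(0)
ref = full_attention(q, k, v)
q_c, k_c, v_c = q.chunk(4), k.chunk(4), v.chunk(4)  # pretend world_size == 4
assert torch.allclose(ring_attention_sim(q_c, k_c, v_c), ref, atol=1e-5)
assert torch.allclose(ulysses_attention_sim(q_c, k_c, v_c), ref, atol=1e-5)
```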
@@ -79,29 +84,46 @@ class ContextParallelConfig:
        if self.ulysses_degree is None:
            self.ulysses_degree = 1

        if self.ring_degree == 1 and self.ulysses_degree == 1:
            raise ValueError(
                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
            )
        if self.ring_degree < 1 or self.ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
        if self.ring_degree > 1 and self.ulysses_degree > 1:
            raise ValueError(
                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
            )
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )

    @property
    def mesh_shape(self) -> Tuple[int, int]:
        return (self.ring_degree, self.ulysses_degree)

    @property
    def mesh_dim_names(self) -> Tuple[str, str]:
        """Dimension names for the device mesh."""
        return ("ring", "ulysses")

    def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )
        if self.ulysses_degree * self.ring_degree > world_size:
            raise ValueError(
                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
            )
        if self._flattened_mesh is None:
            self._flattened_mesh = self._mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self._mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self._mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = self._ring_mesh.get_local_rank()
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()

        self._flattened_mesh = self._mesh._flatten()
        self._ring_mesh = self._mesh["ring"]
        self._ulysses_mesh = self._mesh["ulysses"]
        self._ring_local_rank = self._ring_mesh.get_local_rank()
        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


@dataclass
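As a quick illustration of the validation above, the snippet below constructs a ring-only config and notes which constructions should raise. It assumes `ContextParallelConfig` is importable from the top-level `diffusers` namespace; the exact import path may differ.

```python
# Hypothetical usage of the validation shown above; the import path is an assumption.
from diffusers import ContextParallelConfig

cp = ContextParallelConfig(ring_degree=2)  # ulysses_degree defaults to 1
print(cp.mesh_shape)                       # (2, 1)
print(cp.mesh_dim_names)                   # ("ring", "ulysses")

# Each of these should raise, per the checks above:
# ContextParallelConfig(ring_degree=0)                    -> ValueError (degrees must be >= 1)
# ContextParallelConfig(ring_degree=2, ulysses_degree=2)  -> ValueError (unified Ulysses-Ring not supported yet)
# ContextParallelConfig(rotate_method="alltoall")         -> NotImplementedError (only "allgather")
```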
@@ -119,7 +141,7 @@ class ParallelConfig:
    _rank: int = None
    _world_size: int = None
    _device: torch.device = None
    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _mesh: torch.distributed.device_mesh.DeviceMesh = None

    def setup(
        self,
@@ -127,14 +149,14 @@ class ParallelConfig:
        world_size: int,
        device: torch.device,
        *,
        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._cp_mesh = cp_mesh
        self._mesh = mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
            self.context_parallel_config.setup(rank, world_size, device, mesh)


@dataclass(frozen=True)
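At the user level, the renamed `mesh` argument is an internal detail: a bare `ContextParallelConfig` and an explicit `ParallelConfig` wrapper should end up equivalent once `enable_parallelism` normalizes the former (see the `ModelMixin` hunk further down). Import paths in this sketch are assumptions.

```python
# Hedged sketch: both spellings are expected to be equivalent after normalization.
from diffusers import ContextParallelConfig, ParallelConfig

cfg_short = ContextParallelConfig(ulysses_degree=2)
cfg_full = ParallelConfig(context_parallel_config=ContextParallelConfig(ulysses_degree=2))
# model.enable_parallelism(config=cfg_short)  # wrapped into a ParallelConfig internally
# model.enable_parallelism(config=cfg_full)
```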
@@ -220,7 +220,7 @@ class _AttentionBackendRegistry:
    _backends = {}
    _constraints = {}
    _supported_arg_names = {}
    _supports_context_parallel = {}
    _supports_context_parallel = set()
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS
@@ -237,7 +237,9 @@ class _AttentionBackendRegistry:
            cls._backends[backend] = func
            cls._constraints[backend] = constraints or []
            cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys())
            cls._supports_context_parallel[backend] = supports_context_parallel
            if supports_context_parallel:
                cls._supports_context_parallel.add(backend.value)

            return func

        return decorator
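The change above swaps a backend-to-bool mapping for a set of backend names, so "does this backend support context parallelism" becomes a membership test, which is also what the availability check further down reduces to. A minimal standalone sketch of the pattern, with illustrative names rather than the diffusers internals:

```python
# Minimal standalone sketch of the registration pattern (illustrative names only).
import inspect


class _Registry:
    _backends = {}
    _supported_arg_names = {}
    _supports_context_parallel = set()  # membership == "backend supports CP"

    @classmethod
    def register(cls, name, supports_context_parallel=False):
        def decorator(func):
            cls._backends[name] = func
            cls._supported_arg_names[name] = set(inspect.signature(func).parameters)
            if supports_context_parallel:
                cls._supports_context_parallel.add(name)
            return func
        return decorator

    @classmethod
    def _is_context_parallel_available(cls, name):
        return name in cls._supports_context_parallel


@_Registry.register("flash", supports_context_parallel=True)
def flash_attention(query, key, value):
    ...


@_Registry.register("naive")
def naive_attention(query, key, value):
    ...


print(_Registry._is_context_parallel_available("flash"))  # True
print(_Registry._is_context_parallel_available("naive"))  # False
```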
@@ -251,15 +253,12 @@ class _AttentionBackendRegistry:
        return list(cls._backends.keys())

    @classmethod
    def _is_context_parallel_enabled(
        cls, backend: AttentionBackendName, parallel_config: Optional["ParallelConfig"]
    def _is_context_parallel_available(
        cls,
        backend: AttentionBackendName,
    ) -> bool:
        supports_context_parallel = backend in cls._supports_context_parallel
        is_degree_greater_than_1 = parallel_config is not None and (
            parallel_config.context_parallel_config.ring_degree > 1
            or parallel_config.context_parallel_config.ulysses_degree > 1
        )
        return supports_context_parallel and is_degree_greater_than_1
        supports_context_parallel = backend.value in cls._supports_context_parallel
        return supports_context_parallel


@contextlib.contextmanager
@@ -306,14 +305,6 @@ def dispatch_attention_fn(
    backend_name = AttentionBackendName(backend)
    backend_fn = _AttentionBackendRegistry._backends.get(backend_name)

    if parallel_config is not None and not _AttentionBackendRegistry._is_context_parallel_enabled(
        backend_name, parallel_config
    ):
        raise ValueError(
            f"Backend {backend_name} either does not support context parallelism or context parallelism "
            f"was enabled with a world size of 1."
        )

    kwargs = {
        "query": query,
        "key": key,
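With the compatibility check removed from the hot path, dispatch reduces to resolving the backend and calling it; context-parallel compatibility is now validated once, in `enable_parallelism`, rather than on every attention call. A standalone sketch with illustrative names:

```python
# Standalone sketch of the dispatch hot path after this change (illustrative names).
import torch
import torch.nn.functional as F

_BACKENDS = {"native": F.scaled_dot_product_attention}


def dispatch_attention(backend, query, key, value):
    backend_fn = _BACKENDS[backend]  # resolve and call; no per-call parallel check
    return backend_fn(query, key, value)


q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq, head_dim)
print(dispatch_attention("native", q, k, v).shape)  # torch.Size([1, 2, 8, 16])
```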
@@ -1484,59 +1484,71 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        config: Union[ParallelConfig, ContextParallelConfig],
        cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
    ):
        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        logger.warning(
            "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
        )

        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
            raise RuntimeError(
                "torch.distributed must be available and initialized before calling `enable_parallelism`."
            )

        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
        from .attention_processor import Attention, MochiAttention

        if isinstance(config, ContextParallelConfig):
            config = ParallelConfig(context_parallel_config=config)

        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device_type = torch._C._get_accelerator().type
        device_module = torch.get_device_module(device_type)
        device = torch.device(device_type, rank % device_module.device_count())

        cp_mesh = None
        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)

        if config.context_parallel_config is not None:
            for module in self.modules():
                if not isinstance(module, attention_classes):
                    continue

                processor = module.processor
                if processor is None or not hasattr(processor, "_attention_backend"):
                    continue

                attention_backend = processor._attention_backend
                if attention_backend is None:
                    attention_backend, _ = _AttentionBackendRegistry.get_active_backend()
                else:
                    attention_backend = AttentionBackendName(attention_backend)

                if not _AttentionBackendRegistry._is_context_parallel_available(attention_backend):
                    compatible_backends = sorted(_AttentionBackendRegistry._supports_context_parallel)
                    raise ValueError(
                        f"Context parallelism is enabled but the attention processor '{processor.__class__.__name__}' "
                        f"is using backend '{attention_backend.value}' which does not support context parallelism. "
                        f"Please set a compatible attention backend: {compatible_backends} using `model.set_attention_backend()` before "
                        f"calling `enable_parallelism()`."
                    )

                # All modules use the same attention processor and backend. We don't need to
                # iterate over all modules after checking the first processor
                break

        mesh = None
        if config.context_parallel_config is not None:
            cp_config = config.context_parallel_config
            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
                raise ValueError(
                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
                )
            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
                raise ValueError(
                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
                )
            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
            mesh = torch.distributed.device_mesh.init_device_mesh(
                device_type=device_type,
                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
                mesh_dim_names=("ring", "ulysses"),
                mesh_shape=cp_config.mesh_shape,
                mesh_dim_names=cp_config.mesh_dim_names,
            )

        config.setup(rank, world_size, device, cp_mesh=cp_mesh)

        if cp_plan is None and self._cp_plan is None:
            raise ValueError(
                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
            )
        cp_plan = cp_plan if cp_plan is not None else self._cp_plan

        if config.context_parallel_config is not None:
            apply_context_parallel(self, config.context_parallel_config, cp_plan)

        config.setup(rank, world_size, device, mesh=mesh)
        self._parallel_config = config

        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
        for module in self.modules():
            if not isinstance(module, attention_classes):
                continue
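The mesh handling above amounts to creating one 2-D device mesh named `("ring", "ulysses")` and letting `setup()` read each rank's coordinate along both dimensions. A standalone sketch of the same `torch.distributed` calls, assuming a CPU/gloo run under `torchrun` and a ring-only world_size x 1 shape:

```python
# Run with: torchrun --nproc-per-node=2 mesh_demo.py
# Assumptions: gloo/CPU backend and a ring-only (world_size x 1) mesh, for illustration.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def main():
    dist.init_process_group("gloo")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    ring_degree, ulysses_degree = world_size, 1
    mesh = init_device_mesh(
        device_type="cpu",
        mesh_shape=(ring_degree, ulysses_degree),
        mesh_dim_names=("ring", "ulysses"),
    )
    # Each rank reads its coordinate along both mesh dimensions, as setup() does.
    print(rank, mesh["ring"].get_local_rank(), mesh["ulysses"].get_local_rank())
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```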
@@ -1545,6 +1557,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                continue
            processor._parallel_config = config

        if config.context_parallel_config is not None:
            if cp_plan is None and self._cp_plan is None:
                raise ValueError(
                    "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
                )
            cp_plan = cp_plan if cp_plan is not None else self._cp_plan
            apply_context_parallel(self, config.context_parallel_config, cp_plan)

    @classmethod
    def _load_pretrained_model(
        cls,
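Putting the pieces together, here is a hedged end-to-end sketch of the user-facing flow implied by the hunks above. The model id, dtype, and backend name are illustrative assumptions, and `enable_parallelism` is explicitly experimental, so the call signature may change.

```python
# Launch with: torchrun --nproc-per-node=2 cp_inference.py
# Hedged sketch only; model id, dtype, and backend name are assumptions.
import torch
from diffusers import AutoModel, ContextParallelConfig

torch.distributed.init_process_group("nccl")
rank = torch.distributed.get_rank()
device = torch.device("cuda", rank % torch.cuda.device_count())
torch.cuda.set_device(device)

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to(device)

# A context-parallel-capable attention backend must already be active, otherwise
# enable_parallelism raises the ValueError shown above, e.g. (backend name assumed):
# transformer.set_attention_backend("flash")

# Ring-only context parallelism across 2 ranks, assuming the model class defines a
# `_cp_plan`; otherwise pass `cp_plan=` explicitly.
transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2))
```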