Compare commits


1 Commit

Author     SHA1        Message  Date
sayakpaul  6b127364c4  up       2026-01-23 17:35:26 +05:30
4 changed files with 6 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"
@@ -37,7 +37,7 @@ RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
     torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    --index-url https://download.pytorch.org/whl/cu128
 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"
@@ -37,7 +37,7 @@ RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
     torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    --index-url https://download.pytorch.org/whl/cu128
 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
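
Both Dockerfiles above move from the CUDA 12.1 / Ubuntu 20.04 base image to CUDA 12.8 / Ubuntu 22.04 and switch the PyTorch wheel index from cu121 to cu128. A minimal sketch, not part of the diff, for checking inside the rebuilt image that the installed wheels actually come from the new index:

import torch

# Wheels from the cu128 index report a 12.8 CUDA build; seeing 12.1 here would
# mean the old cu121 wheels are still being pulled in.
print(torch.__version__, torch.version.cuda)
assert torch.version.cuda is not None and torch.version.cuda.startswith("12.8")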

View File

@@ -59,12 +59,6 @@ class ContextParallelConfig:
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
             Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
             is supported.
-        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
-            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
-            creating a new one. This is useful when combining context parallelism with other parallelism strategies
-            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
-            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
-            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
     """
@@ -73,7 +67,6 @@ class ContextParallelConfig:
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
-    mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None
     _rank: int = None
     _world_size: int = None
@@ -122,7 +115,7 @@ class ContextParallelConfig:
                 f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
+        self._flattened_mesh = self._mesh._flatten()
         self._ring_mesh = self._mesh["ring"]
         self._ulysses_mesh = self._mesh["ulysses"]
         self._ring_local_rank = self._ring_mesh.get_local_rank()
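
This change removes the custom `mesh` field (and its docstring) from `ContextParallelConfig`, leaving the config described only by its degrees and rotation options. A hedged sketch of constructing the config after the removal, using only fields visible in this diff (the top-level import path is assumed):

from diffusers import ContextParallelConfig

# ring_degree / ulysses_degree appear in the validation message above;
# convert_to_fp32 and rotate_method are the remaining public fields in the hunk.
cp_config = ContextParallelConfig(
    ring_degree=2,
    ulysses_degree=2,
    convert_to_fp32=True,
    rotate_method="allgather",  # only "allgather" is currently supported
)
# There is no `mesh=` argument anymore; the device mesh is created internally.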

View File

@@ -1569,7 +1569,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
+            mesh = torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
                 mesh_shape=cp_config.mesh_shape,
                 mesh_dim_names=cp_config.mesh_dim_names,
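
With `cp_config.mesh` gone, `ModelMixin` now unconditionally builds the mesh via `init_device_mesh`. A minimal sketch of the equivalent standalone call, assuming 4 GPUs launched with torchrun and an illustrative 2×2 ring/ulysses layout (the exact `mesh_shape` and `mesh_dim_names` values come from the config at runtime):

from torch.distributed.device_mesh import init_device_mesh

# Run under `torchrun --nproc_per_node=4 ...` so the world size matches the mesh.
mesh = init_device_mesh(
    device_type="cuda",
    mesh_shape=(2, 2),                   # (ring_degree, ulysses_degree)
    mesh_dim_names=("ring", "ulysses"),
)
ring_mesh = mesh["ring"]                 # sub-mesh used for ring-attention rotation
ulysses_mesh = mesh["ulysses"]           # sub-mesh used for Ulysses all-to-all
flat_mesh = mesh._flatten()              # mirrors the new self._mesh._flatten() call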