Mirror of https://github.com/huggingface/diffusers.git
Compare commits: custom-dev...cuda-128-g (1 commit)
Commit: 6b127364c4
```diff
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"
 
@@ -37,7 +37,7 @@ RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
     torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    --index-url https://download.pytorch.org/whl/cu128
 
 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
```
```diff
@@ -1,4 +1,4 @@
-FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
+FROM nvidia/cuda:12.8.0-runtime-ubuntu22.04
 LABEL maintainer="Hugging Face"
 LABEL repository="diffusers"
 
@@ -37,7 +37,7 @@ RUN uv pip install --no-cache-dir \
     torch \
     torchvision \
     torchaudio \
-    --index-url https://download.pytorch.org/whl/cu121
+    --index-url https://download.pytorch.org/whl/cu128
 
 RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"
```
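Switching the index URL means the images now pull CUDA 12.8 builds of PyTorch. As a quick sanity check (a sketch, not part of the diff), the installed wheel variant can be verified inside the built image:

```python
# Sketch: confirm the cu128 wheels were installed (assumes the image built from the
# Dockerfile above, with torch pulled from https://download.pytorch.org/whl/cu128).
import torch

print(torch.__version__)          # e.g. "2.x.x+cu128"
print(torch.version.cuda)         # expected "12.8"
print(torch.cuda.is_available())  # True only if a GPU is exposed to the container
```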
```diff
@@ -59,12 +59,6 @@ class ContextParallelConfig:
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
             Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
             is supported.
-        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
-            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
-            creating a new one. This is useful when combining context parallelism with other parallelism strategies
-            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
-            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
-            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
     """
```
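The removed docstring describes passing an externally created mesh that is shared with other parallelism strategies. A minimal sketch of that usage, assuming the `mesh` field from the old side of this diff and PyTorch's `init_device_mesh`:

```python
# Sketch only: combines ring attention with FSDP on one shared mesh, as the removed
# docstring describes. `ContextParallelConfig(mesh=...)` exists only on the old side
# of this diff; the import path is assumed.
from torch.distributed.device_mesh import init_device_mesh
from diffusers import ContextParallelConfig  # assumed top-level export

# 8 GPUs: 2-way ring attention x 4-way FSDP; the size-1 "ulysses" dim is unused.
mesh = init_device_mesh(
    "cuda",
    mesh_shape=(2, 1, 4),
    mesh_dim_names=("ring", "ulysses", "fsdp"),
)

cp_config = ContextParallelConfig(ring_degree=2, mesh=mesh)
# An FSDP wrapper would reuse mesh["fsdp"], while context parallelism uses
# the "ring" and "ulysses" dimensions of the same mesh.
```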
```diff
@@ -73,7 +67,6 @@ class ContextParallelConfig:
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
-    mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None
 
     _rank: int = None
     _world_size: int = None
```
```diff
@@ -122,7 +115,7 @@ class ContextParallelConfig:
                 f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
 
-        self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
+        self._flattened_mesh = self._mesh._flatten()
         self._ring_mesh = self._mesh["ring"]
         self._ulysses_mesh = self._mesh["ulysses"]
         self._ring_local_rank = self._ring_mesh.get_local_rank()
```
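The old line flattens only the context-parallel dimensions, which matters when the mesh carries extra axes such as an FSDP dimension; the new line flattens the whole mesh, which is equivalent only when the mesh contains nothing but `"ring"` and `"ulysses"`. A sketch of the difference, assuming the 3-D mesh from the removed docstring:

```python
# Sketch, assuming an external 3-D mesh as described in the removed docstring.
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cuda", (2, 1, 4), mesh_dim_names=("ring", "ulysses", "fsdp"))

# Old side: flatten only ring x ulysses -> a 2-rank group per FSDP shard.
cp_group = mesh["ring", "ulysses"]._flatten()

# New side: flatten the entire mesh -> one 8-rank group, which is only correct
# when the mesh was created with just the ring/ulysses dimensions.
whole_group = mesh._flatten()
```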
```diff
@@ -1569,7 +1569,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
+            mesh = torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
                 mesh_shape=cp_config.mesh_shape,
                 mesh_dim_names=cp_config.mesh_dim_names,
```
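On the old side, a user-supplied mesh takes precedence and a new one is created only as a fallback; the new side always builds the mesh from the config. A standalone sketch of the old-side selection (the helper name `resolve_cp_mesh` is hypothetical):

```python
# Hypothetical helper sketching the old-side logic: prefer a user-provided mesh,
# otherwise build one from the config's mesh_shape and mesh_dim_names.
from torch.distributed.device_mesh import init_device_mesh

def resolve_cp_mesh(cp_config, device_type="cuda"):
    if getattr(cp_config, "mesh", None) is not None:
        return cp_config.mesh
    return init_device_mesh(
        device_type=device_type,
        mesh_shape=cp_config.mesh_shape,
        mesh_dim_names=cp_config.mesh_dim_names,
    )
```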