Compare commits


2 Commits

Author       SHA1        Message                                                          Date
DN6          eff791831f  update                                                           2026-03-13 10:28:38 +05:30
Dhruv Nair   07c5ba8eee  [Context Parallel] Add support for custom device mesh (#13064)   2026-03-11 16:42:11 +05:30
                         * add custom mesh support
                         * update
                         Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
7 changed files with 145 additions and 737 deletions

View File

@@ -2,7 +2,6 @@
 line-length = 119
 extend-exclude = [
     "src/diffusers/pipelines/flux2/system_messages.py",
-    "src/diffusers/commands/daggr_app.py",
 ]

 [tool.ruff.lint]
View File

@@ -1,726 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import re
import tempfile
import types
import typing
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass, field

from ..utils import logging
from . import BaseDiffusersCLICommand


logger = logging.get_logger("diffusers-cli/daggr")


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
INTERNAL_TYPE_NAMES = {
    "Tensor",
    "Generator",
}
INTERNAL_TYPE_FULL_NAMES = {
    "torch.Tensor",
    "torch.Generator",
    "torch.dtype",
}
SLIDER_PARAMS = {
    "height": {"minimum": 256, "maximum": 2048, "step": 64},
    "width": {"minimum": 256, "maximum": 2048, "step": 64},
    "num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
    "guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
    "strength": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
    "controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}
DROPDOWN_PARAMS = {
    "output_type": {"choices": ["pil", "np", "pt", "latent"], "value": "pil"},
}
DEFAULT_REQUIREMENTS = (
    "torch --index-url https://download.pytorch.org/whl/cu129\n"
    "diffusers\n"
    "transformers\n"
    "accelerate\n"
    "sentencepiece\n"
    "bitsandbytes\n"
    "daggr\n"
    "gradio\n"
)

# ---------------------------------------------------------------------------
# Code templates (string concat so formatters can't mangle newlines)
# ---------------------------------------------------------------------------
_HEADER_TEMPLATE = (
    "import os\n"
    "{extra_imports}\n"
    "import gradio as gr\n"
    "from daggr import FnNode, InputNode, Graph\n"
    "\n"
    "\n"
    "_pipeline = None\n"
    "_exec_blocks = None\n"
    "_tensor_store = {{}}\n"
    "\n"
    "\n"
    "def _get_pipeline():\n"
    "    global _pipeline, _exec_blocks\n"
    "    if _pipeline is None:\n"
    "        from diffusers import ModularPipeline\n"
    "        import torch\n"
    "\n"
    '        _token = os.environ.get("HF_TOKEN")\n'
    '        _device = "cuda" if torch.cuda.is_available() else "cpu"\n'
    "        _pipeline = ModularPipeline.from_pretrained({repo_id}, trust_remote_code=True, token=_token)\n"
    "        _pipeline.load_components(torch_dtype=torch.bfloat16, device_map=_device)\n"
    "        _exec_blocks = {workflow_resolve_code}\n"
    "    return _pipeline, _exec_blocks"
)
_SAVE_IMAGE_TEMPLATE = (
    "\n"
    "\n"
    "def _save_image(val):\n"
    "    from PIL import Image as PILImage\n"
    "\n"
    "    if isinstance(val, list):\n"
    "        paths = [_save_image(item) for item in val]\n"
    "        return paths[0] if paths else None\n"
    "    if isinstance(val, PILImage.Image):\n"
    '        f = tempfile.NamedTemporaryFile(suffix=".png", delete=False)\n'
    "        val.save(f.name)\n"
    "        return f.name\n"
    "    return val"
)
_CACHE_DISABLE_TEMPLATE = (
    "\n"
    "# Disable result caching so every run executes all nodes fresh\n"
    "from daggr.state import SessionState\n"
    "SessionState.get_latest_result = lambda *a, **kw: None\n"
    "SessionState.get_result_by_index = lambda *a, **kw: None\n"
    "SessionState.save_result = lambda *a, **kw: None\n"
)


# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class BlockInfo:
    name: str
    class_name: str
    description: str
    inputs: list
    outputs: list
    user_inputs: list = field(default_factory=list)
    port_connections: list = field(default_factory=list)
    fixed_inputs: list = field(default_factory=list)
    sub_block_names: list = field(default_factory=list)


# ---------------------------------------------------------------------------
# CLI command
# ---------------------------------------------------------------------------
def daggr_command_factory(args: Namespace):
    return DaggrCommand(
        repo_id=args.repo_id,
        output=args.output,
        workflow=getattr(args, "workflow", None),
        trigger_inputs=getattr(args, "trigger_inputs", None),
        deploy=getattr(args, "deploy", None),
        hardware=getattr(args, "hardware", "cpu-basic"),
        private=getattr(args, "private", False),
        requirements=getattr(args, "requirements", None),
    )


class DaggrCommand(BaseDiffusersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
        daggr_parser.add_argument(
            "repo_id",
            type=str,
            help="HuggingFace Hub repo ID containing a modular pipeline.",
        )
        daggr_parser.add_argument(
            "--output",
            type=str,
            default=None,
            help="Save the generated app to a file instead of launching it directly.",
        )
        daggr_parser.add_argument(
            "--workflow",
            type=str,
            default=None,
            help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
        )
        daggr_parser.add_argument(
            "--trigger-inputs",
            nargs="*",
            default=None,
            help="Trigger input names for manual conditional resolution.",
        )
        daggr_parser.add_argument(
            "--deploy",
            type=str,
            default=None,
            metavar="SPACE_NAME",
            help="Deploy the generated app to a HuggingFace Space via daggr deploy.",
        )
        daggr_parser.add_argument(
            "--hardware",
            type=str,
            default="cpu-basic",
            help="Hardware tier for the deployed Space (default: cpu-basic). E.g. a10g-small, a100-large.",
        )
        daggr_parser.add_argument(
            "--private",
            action="store_true",
            default=False,
            help="Make the deployed Space private.",
        )
        daggr_parser.add_argument(
            "--requirements",
            type=str,
            default=None,
            help="Path to a requirements.txt file for the deployed Space.",
        )
        daggr_parser.set_defaults(func=daggr_command_factory)

    def __init__(
        self,
        repo_id: str,
        output: str | None = None,
        workflow: str | None = None,
        trigger_inputs: list | None = None,
        deploy: str | None = None,
        hardware: str = "cpu-basic",
        private: bool = False,
        requirements: str | None = None,
    ):
        self.repo_id = repo_id
        self.output = output
        self.workflow = workflow
        self.trigger_inputs = trigger_inputs
        self.deploy = deploy
        self.hardware = hardware
        self.private = private
        self.requirements = requirements

    def run(self):
        from ..modular_pipelines.modular_pipeline import ModularPipeline

        logger.info(f"Loading blocks from {self.repo_id}...")
        pipeline = ModularPipeline.from_pretrained(self.repo_id, trust_remote_code=True)
        blocks = pipeline._blocks
        blocks_class_name = blocks.__class__.__name__

        if self.workflow:
            logger.info(f"Resolving workflow: {self.workflow}")
            exec_blocks = blocks.get_workflow(self.workflow)
        elif self.trigger_inputs:
            trigger_kwargs = {name: True for name in self.trigger_inputs}
            logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
            exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
        else:
            logger.info("Resolving default execution blocks...")
            exec_blocks = blocks.get_execution_blocks()

        block_infos = _get_block_info(exec_blocks)
        _filter_outputs(block_infos)
        _classify_inputs(block_infos)

        workflow_label = self.workflow or "default"
        workflow_resolve_code = self._get_workflow_resolve_code()
        code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)

        try:
            ast.parse(code)
        except SyntaxError as e:
            logger.warning(f"Generated code has syntax error: {e}")

        if self.deploy:
            self._deploy_to_space(code)
        elif self.output:
            with open(self.output, "w") as f:
                f.write(code)
            print(f"Generated daggr app: {self.output}")
            print(f"  Pipeline: {blocks_class_name}")
            print(f"  Workflow: {workflow_label}")
            print(f"  Blocks: {len(block_infos)}")
            print(f"\nRun with: python {self.output}")
        else:
            print(f"Launching daggr app for {blocks_class_name} ({workflow_label} workflow)...")
            tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
            tmp.write(code)
            tmp.close()
            logger.info(f"Generated temp script: {tmp.name}")
            exec(compile(code, tmp.name, "exec"), {"__name__": "__main__"})

    def _deploy_to_space(self, code):
        import os
        from pathlib import Path

        from daggr.cli import _deploy

        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
        tmp.write(code)
        tmp.close()

        req_path = self.requirements
        if not req_path:
            req_file = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, prefix="daggr_req_")
            req_file.write(DEFAULT_REQUIREMENTS)
            req_file.close()
            req_path = req_file.name

        secrets = {}
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            secrets["HF_TOKEN"] = hf_token

        deploy_name = self.deploy
        deploy_org = None
        if "/" in deploy_name:
            deploy_org, deploy_name = deploy_name.rsplit("/", 1)

        _deploy(
            script_path=Path(tmp.name),
            name=deploy_name,
            title=None,
            org=deploy_org,
            private=self.private,
            hardware=self.hardware,
            secrets=secrets,
            requirements_path=req_path,
            dry_run=False,
        )

    def _get_workflow_resolve_code(self):
        if self.workflow:
            return f"_pipeline._blocks.get_workflow({self.workflow!r})"
        elif self.trigger_inputs:
            kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
            return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
        else:
            return "_pipeline._blocks.get_execution_blocks()"


# ---------------------------------------------------------------------------
# Block analysis
# ---------------------------------------------------------------------------
def _get_block_info(exec_blocks):
    block_infos = []
    for name, block in exec_blocks.sub_blocks.items():
        block_infos.append(
            BlockInfo(
                name=name,
                class_name=block.__class__.__name__,
                description=getattr(block, "description", "") or "",
                inputs=list(block.inputs) if hasattr(block, "inputs") else [],
                outputs=list(block.intermediate_outputs) if hasattr(block, "intermediate_outputs") else [],
            )
        )
    return block_infos


def _filter_outputs(block_infos):
    last_idx = len(block_infos) - 1
    for i, info in enumerate(block_infos):
        downstream_input_names = set()
        for later_info in block_infos[i + 1 :]:
            for inp in later_info.inputs:
                if inp.name:
                    downstream_input_names.add(inp.name)
        is_last = i == last_idx
        info.outputs = [
            out
            for out in info.outputs
            if out.name in downstream_input_names
            or (is_last and (_contains_pil_image(out.type_hint) or not _is_internal_type(out.type_hint)))
        ]


def _classify_inputs(block_infos):
    all_prior_outputs = {}
    for info in block_infos:
        user_inputs, port_connections, fixed_inputs = [], [], []
        for inp in info.inputs:
            if inp.name is None:
                continue
            if inp.name in all_prior_outputs:
                port_connections.append((inp.name, all_prior_outputs[inp.name]))
            elif _is_user_facing(inp):
                user_inputs.append(inp)
            else:
                fixed_inputs.append(inp)
        info.user_inputs = user_inputs
        info.port_connections = port_connections
        info.fixed_inputs = fixed_inputs
        for out in info.outputs:
            if out.name:
                all_prior_outputs[out.name] = info.name


# ---------------------------------------------------------------------------
# Type helpers
# ---------------------------------------------------------------------------
def _get_type_name(type_hint):
    if type_hint is None:
        return None
    if hasattr(type_hint, "__name__"):
        return type_hint.__name__
    if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
        return f"{type_hint.__module__}.{type_hint.__qualname__}"
    return str(type_hint)


def _is_internal_type(type_hint):
    if type_hint is None:
        return False
    type_name = _get_type_name(type_hint)
    if type_name is None:
        return False
    if type_name in INTERNAL_TYPE_NAMES or type_name in INTERNAL_TYPE_FULL_NAMES:
        return True
    type_str = str(type_hint)
    for full_name in INTERNAL_TYPE_FULL_NAMES:
        if full_name in type_str:
            return True
    if type_str.startswith("dict[") or type_str == "dict":
        return True
    return False


def _contains_pil_image(type_hint):
    from PIL import Image

    if type_hint is Image.Image:
        return True
    args = typing.get_args(type_hint)
    return any(_contains_pil_image(a) for a in args) if args else False


def _is_list_or_tuple_type(type_hint):
    origin = typing.get_origin(type_hint)
    if origin in (list, tuple):
        return True
    if origin is typing.Union or origin is types.UnionType:
        return any(_is_list_or_tuple_type(a) for a in typing.get_args(type_hint))
    return False


def _resolve_from_template(inp):
    from ..modular_pipelines.modular_pipeline_utils import INPUT_PARAM_TEMPLATES

    type_hint = inp.type_hint
    default = inp.default
    if inp.name in INPUT_PARAM_TEMPLATES:
        tmpl = INPUT_PARAM_TEMPLATES[inp.name]
        if type_hint is None:
            type_hint = tmpl.get("type_hint")
        if default is None:
            default = tmpl.get("default")
    return type_hint, default


def _is_user_facing(inp):
    type_hint, _ = _resolve_from_template(inp)
    if type_hint is not None:
        if _contains_pil_image(type_hint):
            return True
        return not _is_internal_type(type_hint)
    if inp.name in SLIDER_PARAMS:
        return True
    return False


def _sanitize_name(name):
    sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
    if sanitized and sanitized[0].isdigit():
        sanitized = f"_{sanitized}"
    return sanitized


# ---------------------------------------------------------------------------
# Gradio component code strings
# ---------------------------------------------------------------------------
def _type_hint_to_gradio(type_hint, param_name, default=None):
    if param_name in DROPDOWN_PARAMS:
        opts = DROPDOWN_PARAMS[param_name]
        val = default if default is not None else opts.get("value")
        return f'gr.Dropdown(label="{param_name}", choices={opts["choices"]!r}, value={val!r})'
    if param_name in SLIDER_PARAMS:
        opts = SLIDER_PARAMS[param_name]
        val = default if default is not None else opts.get("minimum", 0)
        return (
            f'gr.Slider(label="{param_name}", value={val!r}, '
            f"minimum={opts['minimum']}, maximum={opts['maximum']}, step={opts['step']})"
        )
    if type_hint is not None and _is_internal_type(type_hint):
        return None
    if type_hint is not None and _contains_pil_image(type_hint):
        return f'gr.Image(label="{param_name}")'
    if type_hint is str:
        default_repr = f", value={default!r}" if default is not None else ""
        return f'gr.Textbox(label="{param_name}", lines=1{default_repr})'
    if type_hint is int:
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}", precision=0{val})'
    if type_hint is float:
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}"{val})'
    if type_hint is bool:
        val = default if default is not None else False
        return f'gr.Checkbox(label="{param_name}", value={val!r})'
    if default is not None:
        return f'gr.Textbox(label="{param_name}", value={default!r})'
    return f'gr.Textbox(label="{param_name}")'


def _output_type_to_gradio(type_hint, param_name):
    if type_hint is not None and _contains_pil_image(type_hint):
        return f'gr.Image(label="{param_name}")'
    if type_hint is str:
        return f'gr.Textbox(label="{param_name}")'
    if type_hint is int or type_hint is float:
        return f'gr.Number(label="{param_name}")'
    return f'gr.Textbox(label="{param_name}", visible=False)'


# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
    sections = []

    # Pre-compute metadata
    blocks_with_image_inputs = {}
    blocks_with_parseable_inputs = {}
    blocks_with_image_outputs = set()
    port_conn_source = {}
    for info in block_infos:
        img_names, parse_names = set(), set()
        for inp in info.inputs:
            if inp.name is None:
                continue
            resolved_type, _ = _resolve_from_template(inp)
            if resolved_type is not None and _contains_pil_image(resolved_type):
                img_names.add(inp.name)
            elif resolved_type is not None and _is_list_or_tuple_type(resolved_type):
                parse_names.add(inp.name)
        if img_names:
            blocks_with_image_inputs[info.name] = img_names
        if parse_names:
            blocks_with_parseable_inputs[info.name] = parse_names
        if any(out.type_hint is not None and _contains_pil_image(out.type_hint) for out in info.outputs):
            blocks_with_image_outputs.add(info.name)
        for conn_name, source_block in info.port_connections:
            port_conn_source[(info.name, conn_name)] = source_block

    needs_image_io = blocks_with_image_inputs or blocks_with_image_outputs
    needs_parsing = bool(blocks_with_parseable_inputs)

    # -- Header & helpers --
    extra_imports = ""
    if needs_parsing:
        extra_imports += "import ast\n"
    if needs_image_io:
        extra_imports += "import tempfile\n"
    sections.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
    sections.append("Generated by: diffusers-cli daggr")
    sections.append('"""')
    header = _HEADER_TEMPLATE.format(
        extra_imports=extra_imports, repo_id=repr(repo_id), workflow_resolve_code=workflow_resolve_code
    )
    sections.extend(header.splitlines())
    if needs_image_io:
        sections.extend(_SAVE_IMAGE_TEMPLATE.splitlines())

    # -- Block functions --
    for info in block_infos:
        fn_name = f"run_{_sanitize_name(info.name)}"
        input_names = [inp.name for inp in info.inputs if inp.name is not None]
        port_conn_names = {c for c, _ in info.port_connections}
        has_image_out = info.name in blocks_with_image_outputs

        body_lines = []
        body_lines.append("    from diffusers.modular_pipelines.modular_pipeline import PipelineState")
        body_lines.append("")
        body_lines.append("    pipe, exec_blocks = _get_pipeline()")
        body_lines.append("    state = PipelineState()")
        if info.name in blocks_with_image_inputs:
            body_lines.append("    from PIL import Image as PILImage")
            body_lines.append("")
            for img_name in blocks_with_image_inputs[info.name]:
                body_lines.append(f"    if {img_name} is not None and isinstance({img_name}, str):")
                body_lines.append(f"        {img_name} = PILImage.open({img_name})")
        if info.name in blocks_with_parseable_inputs:
            for parse_name in blocks_with_parseable_inputs[info.name]:
                body_lines.append(f"    if {parse_name} is not None and isinstance({parse_name}, str):")
                body_lines.append(
                    f"        {parse_name} = ast.literal_eval({parse_name}.strip()) if {parse_name}.strip() else None"
                )
        for n in input_names:
            if n in port_conn_names:
                source = port_conn_source[(info.name, n)]
                body_lines.append(f'    state.set("{n}", _tensor_store.get("{source}:{n}", {n}))')
            else:
                body_lines.append(f'    state.set("{n}", {n})')
        if info.sub_block_names:
            for n in info.sub_block_names:
                body_lines.append(f'    _, state = exec_blocks.sub_blocks["{n}"](pipe, state)')
        else:
            body_lines.append(f'    _, state = exec_blocks.sub_blocks["{info.name}"](pipe, state)')
        for o in info.outputs:
            if o.type_hint is not None and _is_internal_type(o.type_hint):
                body_lines.append(f'    _tensor_store["{info.name}:{o.name}"] = state.get("{o.name}")')
        if len(info.outputs) == 0:
            body_lines.append("    return None")
        elif len(info.outputs) == 1:
            out = info.outputs[0]
            if has_image_out and _contains_pil_image(out.type_hint):
                body_lines.append(f'    return _save_image(state.get("{out.name}"))')
            elif out.type_hint is not None and _is_internal_type(out.type_hint):
                body_lines.append(f'    return "{info.name}:{out.name}"')
            else:
                body_lines.append(f'    return state.get("{out.name}")')
        else:
            parts = []
            for o in info.outputs:
                if has_image_out and o.type_hint is not None and _contains_pil_image(o.type_hint):
                    parts.append(f'_save_image(state.get("{o.name}"))')
                elif o.type_hint is not None and _is_internal_type(o.type_hint):
                    parts.append(f'"{info.name}:{o.name}"')
                else:
                    parts.append(f'state.get("{o.name}")')
            body_lines.append(f"    return {', '.join(parts)}")
        body = "\n".join(body_lines)
        sections.append(f"\n\ndef {fn_name}({', '.join(input_names)}):\n{body}\n")

    # -- Node definitions --
    sections.append("\n\n# -- Pipeline Blocks --")
    node_var_names = {}
    input_node_var_names = []
    user_input_sources = {}
    for info in block_infos:
        var_name = f"{_sanitize_name(info.name)}_node"
        node_var_names[info.name] = var_name
        fn_name = f"run_{_sanitize_name(info.name)}"
        display_name = info.name.replace("_", " ").replace(".", " > ").title()
        for inp in info.user_inputs:
            if inp.name in user_input_sources:
                continue
            resolved_type, resolved_default = _resolve_from_template(inp)
            gradio_comp = _type_hint_to_gradio(resolved_type, inp.name, resolved_default)
            if gradio_comp:
                input_var = f"{_sanitize_name(inp.name)}_input"
                sections.append(
                    f'\n{input_var} = InputNode("{inp.name}", ports={{\n    "{inp.name}": {gradio_comp},\n}})'
                )
                input_node_var_names.append(input_var)
                user_input_sources[inp.name] = input_var
        input_entries = []
        for inp in info.inputs:
            if inp.name is None:
                continue
            connected = False
            for conn_name, source_block in info.port_connections:
                if conn_name == inp.name:
                    input_entries.append(f'        "{inp.name}": {node_var_names[source_block]}.{inp.name},')
                    connected = True
                    break
            if not connected:
                if inp.name in user_input_sources:
                    input_entries.append(f'        "{inp.name}": {user_input_sources[inp.name]}.{inp.name},')
                elif inp.default is not None:
                    input_entries.append(f'        "{inp.name}": {inp.default!r},')
                else:
                    input_entries.append(f'        "{inp.name}": None,')
        output_entries = []
        for out in info.outputs:
            output_entries.append(f'        "{out.name}": {_output_type_to_gradio(out.type_hint, out.name)},')
        node_parts = [f"\n{var_name} = FnNode(", f"    fn={fn_name},", f'    name="{display_name}",']
        if input_entries:
            node_parts.append("    inputs={")
            node_parts.extend(input_entries)
            node_parts.append("    },")
        if output_entries:
            node_parts.append("    outputs={")
            node_parts.extend(output_entries)
            node_parts.append("    },")
        node_parts.append(")")
        sections.append("\n".join(node_parts))

    # -- Graph launch with cache disabled --
    all_node_vars = input_node_var_names + [node_var_names[info.name] for info in block_infos]
    graph_name = f"{blocks_class_name} - {workflow_label}"
    nodes_str = ", ".join(all_node_vars)
    sections.append("")
    sections.append("")
    sections.append("# -- Graph --")
    sections.extend(_CACHE_DISABLE_TEMPLATE.splitlines())
    sections.append("")
    sections.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
    sections.append("graph.launch()")
    sections.append("")
    return "\n".join(sections)

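For reference, the module deleted above implemented the `diffusers-cli daggr` subcommand. A minimal sketch of driving it programmatically, using only the API shown in the deleted file (the repo ID is a placeholder, and this works only on revisions that still ship the module):

    # Equivalent to: diffusers-cli daggr <repo_id> --workflow text2image --output app.py
    from diffusers.commands.daggr_app import DaggrCommand

    cmd = DaggrCommand(
        repo_id="<user>/<modular-pipeline-repo>",  # placeholder Hub repo containing a modular pipeline
        output="app.py",  # write the generated daggr/Gradio script instead of launching it
        workflow="text2image",
    )
    cmd.run()  # resolves blocks, validates the generated code with ast.parse, writes app.py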
View File

@@ -16,7 +16,6 @@
 from argparse import ArgumentParser

 from .custom_blocks import CustomBlocksCommand
-from .daggr_app import DaggrCommand
 from .env import EnvironmentCommand
 from .fp16_safetensors import FP16SafetensorsCommand
@@ -29,7 +28,6 @@ def main():
     EnvironmentCommand.register_subcommand(commands_parser)
     FP16SafetensorsCommand.register_subcommand(commands_parser)
     CustomBlocksCommand.register_subcommand(commands_parser)
-    DaggrCommand.register_subcommand(commands_parser)

     # Let's go
     args = parser.parse_args()

View File

@@ -60,6 +60,16 @@ class ContextParallelConfig:
         rotate_method (`str`, *optional*, defaults to `"allgather"`):
             Method to use for rotating key/value states across devices in ring attention. Currently, only
             `"allgather"` is supported.
+        ulysses_anything (`bool`, *optional*, defaults to `False`):
+            Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
+            are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
+            `ring_degree` must be 1.
+        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
+            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
+            creating a new one. This is useful when combining context parallelism with other parallelism strategies
+            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
+            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
+            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
     """
@@ -68,6 +78,7 @@ class ContextParallelConfig:
     convert_to_fp32: bool = True
     # TODO: support alltoall
     rotate_method: Literal["allgather", "alltoall"] = "allgather"
+    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
     # Whether to enable ulysses anything attention to support
     # any sequence lengths and any head numbers.
     ulysses_anything: bool = False
@@ -124,7 +135,7 @@ class ContextParallelConfig:
                 f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
             )
-        self._flattened_mesh = self._mesh._flatten()
+        self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
         self._ring_mesh = self._mesh["ring"]
         self._ulysses_mesh = self._mesh["ulysses"]
         self._ring_local_rank = self._ring_mesh.get_local_rank()

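To make the new `mesh` parameter concrete, here is a minimal sketch of sharing one 3D device mesh between ring attention and FSDP. It assumes 8 GPUs, an already-initialized process group, and that `ContextParallelConfig` is importable from the top-level `diffusers` namespace; `transformer` is a placeholder for any model exposing `enable_parallelism`:

    import torch
    from diffusers import ContextParallelConfig

    # One shared mesh: "ring" and "ulysses" dims are required by the config;
    # the unused Ulysses dim gets size 1, and "fsdp" is left for weight sharding.
    mesh = torch.distributed.device_mesh.init_device_mesh(
        "cuda",
        mesh_shape=(2, 1, 4),
        mesh_dim_names=("ring", "ulysses", "fsdp"),
    )

    config = ContextParallelConfig(ring_degree=2, mesh=mesh)
    # transformer.enable_parallelism(config=config)  # reuses `mesh` instead of creating a new one
    # mesh["fsdp"] remains available for composing FSDP over the same devices.

Per the docstring above, the separate `ulysses_anything=True` mode instead requires `ulysses_degree` greater than 1 and `ring_degree` of 1.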
View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         mesh = None
         if config.context_parallel_config is not None:
             cp_config = config.context_parallel_config
-            mesh = torch.distributed.device_mesh.init_device_mesh(
+            mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
                 device_type=device_type,
                 mesh_shape=cp_config.mesh_shape,
                 mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -14,6 +14,7 @@
 import importlib
 import inspect
 import os
+import shutil
 import sys
 import traceback
 import warnings
@@ -1883,6 +1884,36 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
         )
         return pipeline

+    def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
+        """Save custom code files (blocks config and Python modules) to the save directory."""
+        if self._blocks is None:
+            return
+
+        blocks_module = type(self._blocks).__module__
+        is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
+        if not is_custom_code:
+            return
+
+        os.makedirs(save_directory, exist_ok=True)
+        self._blocks.save_pretrained(save_directory)
+
+        source_file = inspect.getfile(type(self._blocks))
+        module_file = os.path.basename(source_file)
+        dest_file = os.path.join(save_directory, module_file)
+        if os.path.abspath(source_file) != os.path.abspath(dest_file):
+            shutil.copyfile(source_file, dest_file)
+
+        from ..utils.dynamic_modules_utils import get_relative_import_files
+
+        for rel_file in get_relative_import_files(source_file):
+            rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
+            rel_dest = os.path.join(save_directory, rel_name)
+            if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
+                os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
+                shutil.copyfile(rel_file, rel_dest)
+
     def save_pretrained(
         self,
         save_directory: str | os.PathLike,
@@ -1998,6 +2029,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                 component_spec_dict["subfolder"] = component_name
             self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})

+        self._maybe_save_custom_code(save_directory)
+
         self.save_config(save_directory=save_directory)

         if push_to_hub:

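A short sketch of what the new hook changes for users with custom blocks (repo ID and target directory are placeholders; `trust_remote_code=True` is needed to load remote block code):

    from diffusers import ModularPipeline

    pipe = ModularPipeline.from_pretrained("<user>/<custom-modular-repo>", trust_remote_code=True)
    pipe.save_pretrained("./saved_pipeline")
    # Because the blocks class is defined outside the `diffusers` package,
    # _maybe_save_custom_code now copies its defining .py module, plus any
    # files it imports relatively, into ./saved_pipeline before the config
    # is written, so the saved directory is self-contained.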
View File

@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_dict
     model.eval()

     # Move inputs to device
-    inputs_on_device = {}
-    for key, value in inputs_dict.items():
-        if isinstance(value, torch.Tensor):
-            inputs_on_device[key] = value.to(device)
-        else:
-            inputs_on_device[key] = value
+    inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

     # Enable context parallelism
     cp_config = ContextParallelConfig(**cp_dict)
@@ -89,6 +84,59 @@
     dist.destroy_process_group()


+def _custom_mesh_worker(
+    rank,
+    world_size,
+    master_port,
+    model_class,
+    init_dict,
+    cp_dict,
+    mesh_shape,
+    mesh_dim_names,
+    inputs_dict,
+    return_dict,
+):
+    """Worker function for context parallel testing with a user-provided custom DeviceMesh."""
+    try:
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = str(master_port)
+        os.environ["RANK"] = str(rank)
+        os.environ["WORLD_SIZE"] = str(world_size)
+        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
+
+        # DeviceMesh must be created after init_process_group, inside each worker process.
+        mesh = torch.distributed.device_mesh.init_device_mesh(
+            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
+        )
+        cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device, return_dict=False)[0]
+
+        if rank == 0:
+            return_dict["status"] = "success"
+            return_dict["output_shape"] = list(output.shape)
+    except Exception as e:
+        if rank == 0:
+            return_dict["status"] = "error"
+            return_dict["error"] = str(e)
+    finally:
+        if dist.is_initialized():
+            dist.destroy_process_group()
+
+
 @is_context_parallel
 @require_torch_multi_accelerator
 class ContextParallelTesterMixin:
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
         assert return_dict.get("status") == "success", (
             f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
         )
+
+    @pytest.mark.parametrize(
+        "cp_type,mesh_shape,mesh_dim_names",
+        [
+            ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
+            ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
+        ],
+        ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
+    )
+    def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
+        cp_dict = {cp_type: world_size}
+
+        master_port = _find_free_port()
+        manager = mp.Manager()
+        return_dict = manager.dict()
+
+        mp.spawn(
+            _custom_mesh_worker,
+            args=(
+                world_size,
+                master_port,
+                self.model_class,
+                init_dict,
+                cp_dict,
+                mesh_shape,
+                mesh_dim_names,
+                inputs_dict,
+                return_dict,
+            ),
+            nprocs=world_size,
+            join=True,
+        )
+
+        assert return_dict.get("status") == "success", (
+            f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
+        )