Compare commits

..

3 Commits

Author SHA1 Message Date
DN6
4517b9311a update 2026-03-14 08:17:18 +05:30
DN6
7e3e640b5a update 2026-03-14 08:03:47 +05:30
DN6
7b961f07e7 update 2026-03-13 17:53:28 +05:30
6 changed files with 457 additions and 145 deletions

View File

@@ -0,0 +1,447 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import re
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
from dataclasses import dataclass, field
from ..utils import logging
from . import BaseDiffusersCLICommand
logger = logging.get_logger("diffusers-cli/daggr")
INTERNAL_TYPE_NAMES = {
"Tensor",
"Generator",
}
INTERNAL_TYPE_FULL_NAMES = {
"torch.Tensor",
"torch.Generator",
"torch.dtype",
}
SLIDER_PARAMS = {
"height": {"minimum": 256, "maximum": 2048, "step": 64},
"width": {"minimum": 256, "maximum": 2048, "step": 64},
"num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
"guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
"strength": {"minimum": 0, "maximum": 1, "step": 0.05},
"control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
"control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
"controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}
@dataclass
class BlockInfo:
name: str
class_name: str
description: str
inputs: list
outputs: list
user_inputs: list = field(default_factory=list)
port_connections: list = field(default_factory=list)
fixed_inputs: list = field(default_factory=list)
def daggr_command_factory(args: Namespace):
return DaggrCommand(
repo_id=args.repo_id,
output=args.output or "daggr_app.py",
workflow=getattr(args, "workflow", None),
trigger_inputs=getattr(args, "trigger_inputs", None),
)
class DaggrCommand(BaseDiffusersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
daggr_parser.add_argument(
"repo_id",
type=str,
help="HuggingFace Hub repo ID containing a modular pipeline (with modular_model_index.json).",
)
daggr_parser.add_argument(
"--output",
type=str,
default="daggr_app.py",
help="Output file path for the generated daggr app. Default: daggr_app.py",
)
daggr_parser.add_argument(
"--workflow",
type=str,
default=None,
help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
)
daggr_parser.add_argument(
"--trigger-inputs",
nargs="*",
default=None,
help="Trigger input names for manual conditional resolution.",
)
daggr_parser.set_defaults(func=daggr_command_factory)
def __init__(
self,
repo_id: str,
output: str = "daggr_app.py",
workflow: str | None = None,
trigger_inputs: list | None = None,
):
self.repo_id = repo_id
self.output = output
self.workflow = workflow
self.trigger_inputs = trigger_inputs
def run(self):
from ..modular_pipelines.modular_pipeline import ModularPipelineBlocks
logger.info(f"Loading blocks from {self.repo_id}...")
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
blocks_class_name = blocks.__class__.__name__
if self.workflow:
logger.info(f"Resolving workflow: {self.workflow}")
exec_blocks = blocks.get_workflow(self.workflow)
elif self.trigger_inputs:
trigger_kwargs = {name: True for name in self.trigger_inputs}
logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
else:
logger.info("Resolving default execution blocks...")
exec_blocks = blocks.get_execution_blocks()
block_infos = _analyze_blocks(exec_blocks)
_classify_inputs(block_infos)
workflow_label = self.workflow or "default"
workflow_resolve_code = self._get_workflow_resolve_code()
code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)
try:
ast.parse(code)
except SyntaxError as e:
logger.warning(f"Generated code has syntax error: {e}")
with open(self.output, "w") as f:
f.write(code)
logger.info(f"Daggr app written to {self.output}")
print(f"Generated daggr app: {self.output}")
print(f" Pipeline: {blocks_class_name}")
print(f" Workflow: {workflow_label}")
print(f" Blocks: {len(block_infos)}")
print(f"\nRun with: python {self.output}")
def _get_workflow_resolve_code(self):
if self.workflow:
return f"_pipeline._blocks.get_workflow({self.workflow!r})"
elif self.trigger_inputs:
kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
else:
return "_pipeline._blocks.get_execution_blocks()"
def _analyze_blocks(exec_blocks):
block_infos = []
for name, block in exec_blocks.sub_blocks.items():
info = BlockInfo(
name=name,
class_name=block.__class__.__name__,
description=getattr(block, "description", "") or "",
inputs=list(block.inputs) if hasattr(block, "inputs") else [],
outputs=list(block.intermediate_outputs) if hasattr(block, "intermediate_outputs") else [],
)
block_infos.append(info)
return block_infos
def _get_type_name(type_hint):
if type_hint is None:
return None
if hasattr(type_hint, "__name__"):
return type_hint.__name__
if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
return f"{type_hint.__module__}.{type_hint.__qualname__}"
return str(type_hint)
def _is_internal_type(type_hint):
if type_hint is None:
return True
type_name = _get_type_name(type_hint)
if type_name is None:
return True
if type_name in INTERNAL_TYPE_NAMES or type_name in INTERNAL_TYPE_FULL_NAMES:
return True
type_str = str(type_hint)
for full_name in INTERNAL_TYPE_FULL_NAMES:
if full_name in type_str:
return True
if type_str.startswith("dict[") or type_str == "dict":
return True
return False
def _type_hint_to_gradio(type_hint, param_name, default=None):
if _is_internal_type(type_hint):
return None
if param_name in SLIDER_PARAMS:
slider_opts = SLIDER_PARAMS[param_name]
val = default if default is not None else slider_opts.get("minimum", 0)
return (
f'gr.Slider(label="{param_name}", value={val!r}, '
f"minimum={slider_opts['minimum']}, maximum={slider_opts['maximum']}, "
f"step={slider_opts['step']})"
)
type_name = _get_type_name(type_hint)
type_str = str(type_hint)
if type_name == "str" or type_hint is str:
lines = 3 if "prompt" in param_name else 1
default_repr = f", value={default!r}" if default is not None else ""
return f'gr.Textbox(label="{param_name}", lines={lines}{default_repr})'
if type_name == "int" or type_hint is int:
val = f", value={default!r}" if default is not None else ""
return f'gr.Number(label="{param_name}", precision=0{val})'
if type_name == "float" or type_hint is float:
val = f", value={default!r}" if default is not None else ""
return f'gr.Number(label="{param_name}"{val})'
if type_name == "bool" or type_hint is bool:
val = default if default is not None else False
return f'gr.Checkbox(label="{param_name}", value={val!r})'
if "Image" in type_str:
if "list" in type_str.lower():
return f'gr.Gallery(label="{param_name}")'
return f'gr.Image(label="{param_name}")'
if default is not None:
return f'gr.Textbox(label="{param_name}", value={default!r})'
return f'gr.Textbox(label="{param_name}")'
def _output_type_to_gradio(type_hint, param_name):
if _is_internal_type(type_hint):
return None
type_str = str(type_hint)
if "Image" in type_str:
if "list" in type_str.lower():
return f'gr.Gallery(label="{param_name}")'
return f'gr.Image(label="{param_name}")'
if type_hint is str:
return f'gr.Textbox(label="{param_name}")'
if type_hint is int or type_hint is float:
return f'gr.Number(label="{param_name}")'
return None
def _classify_inputs(block_infos):
all_prior_outputs = {}
for info in block_infos:
user_inputs = []
port_connections = []
fixed_inputs = []
for inp in info.inputs:
if inp.name is None:
continue
if inp.name in all_prior_outputs:
port_connections.append((inp.name, all_prior_outputs[inp.name]))
elif _is_internal_type(inp.type_hint):
fixed_inputs.append(inp)
else:
user_inputs.append(inp)
info.user_inputs = user_inputs
info.port_connections = port_connections
info.fixed_inputs = fixed_inputs
for out in info.outputs:
if out.name and out.name not in all_prior_outputs:
all_prior_outputs[out.name] = info.name
def _sanitize_name(name):
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
if sanitized and sanitized[0].isdigit():
sanitized = f"_{sanitized}"
return sanitized
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
lines = []
lines.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
lines.append("Generated by: diffusers-cli daggr")
lines.append('"""')
lines.append("")
lines.append("import gradio as gr")
lines.append("from daggr import FnNode, InputNode, Graph")
lines.append("")
lines.append("")
# Pipeline and resolved blocks loader
lines.append("_pipeline = None")
lines.append("_exec_blocks = None")
lines.append("")
lines.append("")
lines.append("def _get_pipeline():")
lines.append(" global _pipeline, _exec_blocks")
lines.append(" if _pipeline is None:")
lines.append(" from diffusers import ModularPipeline")
lines.append(f" _pipeline = ModularPipeline.from_pretrained({repo_id!r}, trust_remote_code=True)")
lines.append(" _pipeline.load_components()")
lines.append(f" _exec_blocks = {workflow_resolve_code}")
lines.append(" return _pipeline, _exec_blocks")
lines.append("")
lines.append("")
# Wrapper functions
for info in block_infos:
fn_name = f"run_{_sanitize_name(info.name)}"
all_input_names = []
for inp in info.inputs:
if inp.name is not None:
all_input_names.append(inp.name)
params = ", ".join(all_input_names)
lines.append(f"def {fn_name}({params}):")
lines.append(" from diffusers.modular_pipelines.modular_pipeline import PipelineState")
lines.append("")
lines.append(" pipe, exec_blocks = _get_pipeline()")
lines.append(" state = PipelineState()")
for inp_name in all_input_names:
lines.append(f' state.set("{inp_name}", {inp_name})')
lines.append(f' block = exec_blocks.sub_blocks["{info.name}"]')
lines.append(" _, state = block(pipe, state)")
if len(info.outputs) == 0:
lines.append(" return None")
elif len(info.outputs) == 1:
out = info.outputs[0]
lines.append(f' return state.get("{out.name}")')
else:
out_names = [out.name for out in info.outputs]
out_dict = ", ".join(f'"{n}": state.get("{n}")' for n in out_names)
lines.append(f" return {{{out_dict}}}")
lines.append("")
lines.append("")
# Collect all user-facing inputs across blocks
all_user_inputs = OrderedDict()
for info in block_infos:
for inp in info.user_inputs:
if inp.name not in all_user_inputs:
all_user_inputs[inp.name] = inp
# InputNode
if all_user_inputs:
lines.append("# -- User Inputs --")
lines.append('user_inputs = InputNode("User Inputs", ports={')
for inp_name, inp in all_user_inputs.items():
gradio_comp = _type_hint_to_gradio(inp.type_hint, inp_name, inp.default)
if gradio_comp:
lines.append(f' "{inp_name}": {gradio_comp},')
lines.append("})")
lines.append("")
lines.append("")
# FnNode definitions
lines.append("# -- Pipeline Blocks --")
node_var_names = {}
for info in block_infos:
var_name = f"{_sanitize_name(info.name)}_node"
node_var_names[info.name] = var_name
fn_name = f"run_{_sanitize_name(info.name)}"
display_name = info.name.replace("_", " ").replace(".", " > ").title()
# Build inputs dict
input_entries = []
for inp in info.inputs:
if inp.name is None:
continue
connected = False
for conn_name, source_block in info.port_connections:
if conn_name == inp.name:
source_var = node_var_names[source_block]
input_entries.append(f' "{inp.name}": {source_var}.{inp.name},')
connected = True
break
if not connected:
if inp.name in all_user_inputs:
input_entries.append(f' "{inp.name}": user_inputs.{inp.name},')
elif inp.default is not None:
input_entries.append(f' "{inp.name}": {inp.default!r},')
else:
input_entries.append(f' "{inp.name}": None,')
# Build outputs dict
output_entries = []
for out in info.outputs:
gradio_out = _output_type_to_gradio(out.type_hint, out.name)
if gradio_out:
output_entries.append(f' "{out.name}": {gradio_out},')
else:
output_entries.append(f' "{out.name}": None,')
lines.append(f"{var_name} = FnNode(")
lines.append(f" fn={fn_name},")
lines.append(f' name="{display_name}",')
if input_entries:
lines.append(" inputs={")
lines.extend(input_entries)
lines.append(" },")
if output_entries:
lines.append(" outputs={")
lines.extend(output_entries)
lines.append(" },")
lines.append(")")
lines.append("")
# Graph
lines.append("")
lines.append("# -- Graph --")
all_node_vars = []
if all_user_inputs:
all_node_vars.append("user_inputs")
all_node_vars.extend(node_var_names[info.name] for info in block_infos)
graph_name = f"{blocks_class_name} - {workflow_label}"
nodes_str = ", ".join(all_node_vars)
lines.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
lines.append("graph.launch()")
lines.append("")
return "\n".join(lines)

View File

@@ -16,6 +16,7 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .daggr_app import DaggrCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -28,6 +29,7 @@ def main():
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
CustomBlocksCommand.register_subcommand(commands_parser)
DaggrCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()

View File

@@ -60,16 +60,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -78,7 +68,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -135,7 +124,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -14,7 +14,6 @@
import importlib
import inspect
import os
import shutil
import sys
import traceback
import warnings
@@ -1884,36 +1883,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def _maybe_save_custom_code(self, save_directory: str | os.PathLike):
"""Save custom code files (blocks config and Python modules) to the save directory."""
if self._blocks is None:
return
blocks_module = type(self._blocks).__module__
is_custom_code = not blocks_module.startswith("diffusers.") and blocks_module != "diffusers"
if not is_custom_code:
return
os.makedirs(save_directory, exist_ok=True)
self._blocks.save_pretrained(save_directory)
source_file = inspect.getfile(type(self._blocks))
module_file = os.path.basename(source_file)
dest_file = os.path.join(save_directory, module_file)
if os.path.abspath(source_file) != os.path.abspath(dest_file):
shutil.copyfile(source_file, dest_file)
from ..utils.dynamic_modules_utils import get_relative_import_files
for rel_file in get_relative_import_files(source_file):
rel_name = os.path.relpath(rel_file, os.path.dirname(source_file))
rel_dest = os.path.join(save_directory, rel_name)
if os.path.abspath(rel_file) != os.path.abspath(rel_dest):
os.makedirs(os.path.dirname(rel_dest), exist_ok=True)
shutil.copyfile(rel_file, rel_dest)
def save_pretrained(
self,
save_directory: str | os.PathLike,
@@ -2029,8 +1998,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self._maybe_save_custom_code(save_directory)
self.save_config(save_directory=save_directory)
if push_to_hub:

View File

@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)