Compare commits

..

3 Commits

Author SHA1 Message Date
DN6
4517b9311a update 2026-03-14 08:17:18 +05:30
DN6
7e3e640b5a update 2026-03-14 08:03:47 +05:30
DN6
7b961f07e7 update 2026-03-13 17:53:28 +05:30
13 changed files with 592 additions and 495 deletions

View File

@@ -532,6 +532,8 @@
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
@@ -675,8 +677,6 @@
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios

View File

@@ -21,31 +21,29 @@
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
## Basic usage
## Loading original format checkpoints
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
```python
import torch
from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
model_id = "nvidia/Cosmos-Predict2.5-2B"
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
)
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
transformer = CosmosTransformer3DModel.from_single_file(
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
torch_dtype=torch.bfloat16,
).to("cuda")
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
output = pipe(
image=None,
video=None,
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=93,
generator=torch.Generator().manual_seed(1),
).frames[0]
export_to_video(output, "text2world.mp4", fps=16)
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
).images[0]
output.save("output.png")
```
## Cosmos2_5_TransferPipeline

View File

@@ -44,7 +44,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
| [ControlNet-XS](controlnetxs) | text2image |
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
| [Cosmos](cosmos) | text2video, video2video |
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
| [DDIM](ddim) | unconditional image generation |
| [DDPM](ddpm) | unconditional image generation |

View File

@@ -0,0 +1,447 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import re
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
from dataclasses import dataclass, field
from ..utils import logging
from . import BaseDiffusersCLICommand
logger = logging.get_logger("diffusers-cli/daggr")
INTERNAL_TYPE_NAMES = {
"Tensor",
"Generator",
}
INTERNAL_TYPE_FULL_NAMES = {
"torch.Tensor",
"torch.Generator",
"torch.dtype",
}
SLIDER_PARAMS = {
"height": {"minimum": 256, "maximum": 2048, "step": 64},
"width": {"minimum": 256, "maximum": 2048, "step": 64},
"num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
"guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
"strength": {"minimum": 0, "maximum": 1, "step": 0.05},
"control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
"control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
"controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}
@dataclass
class BlockInfo:
name: str
class_name: str
description: str
inputs: list
outputs: list
user_inputs: list = field(default_factory=list)
port_connections: list = field(default_factory=list)
fixed_inputs: list = field(default_factory=list)
def daggr_command_factory(args: Namespace):
return DaggrCommand(
repo_id=args.repo_id,
output=args.output or "daggr_app.py",
workflow=getattr(args, "workflow", None),
trigger_inputs=getattr(args, "trigger_inputs", None),
)
class DaggrCommand(BaseDiffusersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
daggr_parser.add_argument(
"repo_id",
type=str,
help="HuggingFace Hub repo ID containing a modular pipeline (with modular_model_index.json).",
)
daggr_parser.add_argument(
"--output",
type=str,
default="daggr_app.py",
help="Output file path for the generated daggr app. Default: daggr_app.py",
)
daggr_parser.add_argument(
"--workflow",
type=str,
default=None,
help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
)
daggr_parser.add_argument(
"--trigger-inputs",
nargs="*",
default=None,
help="Trigger input names for manual conditional resolution.",
)
daggr_parser.set_defaults(func=daggr_command_factory)
def __init__(
self,
repo_id: str,
output: str = "daggr_app.py",
workflow: str | None = None,
trigger_inputs: list | None = None,
):
self.repo_id = repo_id
self.output = output
self.workflow = workflow
self.trigger_inputs = trigger_inputs
def run(self):
from ..modular_pipelines.modular_pipeline import ModularPipelineBlocks
logger.info(f"Loading blocks from {self.repo_id}...")
blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
blocks_class_name = blocks.__class__.__name__
if self.workflow:
logger.info(f"Resolving workflow: {self.workflow}")
exec_blocks = blocks.get_workflow(self.workflow)
elif self.trigger_inputs:
trigger_kwargs = {name: True for name in self.trigger_inputs}
logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
else:
logger.info("Resolving default execution blocks...")
exec_blocks = blocks.get_execution_blocks()
block_infos = _analyze_blocks(exec_blocks)
_classify_inputs(block_infos)
workflow_label = self.workflow or "default"
workflow_resolve_code = self._get_workflow_resolve_code()
code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)
try:
ast.parse(code)
except SyntaxError as e:
logger.warning(f"Generated code has syntax error: {e}")
with open(self.output, "w") as f:
f.write(code)
logger.info(f"Daggr app written to {self.output}")
print(f"Generated daggr app: {self.output}")
print(f" Pipeline: {blocks_class_name}")
print(f" Workflow: {workflow_label}")
print(f" Blocks: {len(block_infos)}")
print(f"\nRun with: python {self.output}")
def _get_workflow_resolve_code(self):
if self.workflow:
return f"_pipeline._blocks.get_workflow({self.workflow!r})"
elif self.trigger_inputs:
kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
else:
return "_pipeline._blocks.get_execution_blocks()"
def _analyze_blocks(exec_blocks):
block_infos = []
for name, block in exec_blocks.sub_blocks.items():
info = BlockInfo(
name=name,
class_name=block.__class__.__name__,
description=getattr(block, "description", "") or "",
inputs=list(block.inputs) if hasattr(block, "inputs") else [],
outputs=list(block.intermediate_outputs) if hasattr(block, "intermediate_outputs") else [],
)
block_infos.append(info)
return block_infos
def _get_type_name(type_hint):
if type_hint is None:
return None
if hasattr(type_hint, "__name__"):
return type_hint.__name__
if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
return f"{type_hint.__module__}.{type_hint.__qualname__}"
return str(type_hint)
def _is_internal_type(type_hint):
if type_hint is None:
return True
type_name = _get_type_name(type_hint)
if type_name is None:
return True
if type_name in INTERNAL_TYPE_NAMES or type_name in INTERNAL_TYPE_FULL_NAMES:
return True
type_str = str(type_hint)
for full_name in INTERNAL_TYPE_FULL_NAMES:
if full_name in type_str:
return True
if type_str.startswith("dict[") or type_str == "dict":
return True
return False
def _type_hint_to_gradio(type_hint, param_name, default=None):
if _is_internal_type(type_hint):
return None
if param_name in SLIDER_PARAMS:
slider_opts = SLIDER_PARAMS[param_name]
val = default if default is not None else slider_opts.get("minimum", 0)
return (
f'gr.Slider(label="{param_name}", value={val!r}, '
f"minimum={slider_opts['minimum']}, maximum={slider_opts['maximum']}, "
f"step={slider_opts['step']})"
)
type_name = _get_type_name(type_hint)
type_str = str(type_hint)
if type_name == "str" or type_hint is str:
lines = 3 if "prompt" in param_name else 1
default_repr = f", value={default!r}" if default is not None else ""
return f'gr.Textbox(label="{param_name}", lines={lines}{default_repr})'
if type_name == "int" or type_hint is int:
val = f", value={default!r}" if default is not None else ""
return f'gr.Number(label="{param_name}", precision=0{val})'
if type_name == "float" or type_hint is float:
val = f", value={default!r}" if default is not None else ""
return f'gr.Number(label="{param_name}"{val})'
if type_name == "bool" or type_hint is bool:
val = default if default is not None else False
return f'gr.Checkbox(label="{param_name}", value={val!r})'
if "Image" in type_str:
if "list" in type_str.lower():
return f'gr.Gallery(label="{param_name}")'
return f'gr.Image(label="{param_name}")'
if default is not None:
return f'gr.Textbox(label="{param_name}", value={default!r})'
return f'gr.Textbox(label="{param_name}")'
def _output_type_to_gradio(type_hint, param_name):
if _is_internal_type(type_hint):
return None
type_str = str(type_hint)
if "Image" in type_str:
if "list" in type_str.lower():
return f'gr.Gallery(label="{param_name}")'
return f'gr.Image(label="{param_name}")'
if type_hint is str:
return f'gr.Textbox(label="{param_name}")'
if type_hint is int or type_hint is float:
return f'gr.Number(label="{param_name}")'
return None
def _classify_inputs(block_infos):
all_prior_outputs = {}
for info in block_infos:
user_inputs = []
port_connections = []
fixed_inputs = []
for inp in info.inputs:
if inp.name is None:
continue
if inp.name in all_prior_outputs:
port_connections.append((inp.name, all_prior_outputs[inp.name]))
elif _is_internal_type(inp.type_hint):
fixed_inputs.append(inp)
else:
user_inputs.append(inp)
info.user_inputs = user_inputs
info.port_connections = port_connections
info.fixed_inputs = fixed_inputs
for out in info.outputs:
if out.name and out.name not in all_prior_outputs:
all_prior_outputs[out.name] = info.name
def _sanitize_name(name):
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
if sanitized and sanitized[0].isdigit():
sanitized = f"_{sanitized}"
return sanitized
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
lines = []
lines.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
lines.append("Generated by: diffusers-cli daggr")
lines.append('"""')
lines.append("")
lines.append("import gradio as gr")
lines.append("from daggr import FnNode, InputNode, Graph")
lines.append("")
lines.append("")
# Pipeline and resolved blocks loader
lines.append("_pipeline = None")
lines.append("_exec_blocks = None")
lines.append("")
lines.append("")
lines.append("def _get_pipeline():")
lines.append(" global _pipeline, _exec_blocks")
lines.append(" if _pipeline is None:")
lines.append(" from diffusers import ModularPipeline")
lines.append(f" _pipeline = ModularPipeline.from_pretrained({repo_id!r}, trust_remote_code=True)")
lines.append(" _pipeline.load_components()")
lines.append(f" _exec_blocks = {workflow_resolve_code}")
lines.append(" return _pipeline, _exec_blocks")
lines.append("")
lines.append("")
# Wrapper functions
for info in block_infos:
fn_name = f"run_{_sanitize_name(info.name)}"
all_input_names = []
for inp in info.inputs:
if inp.name is not None:
all_input_names.append(inp.name)
params = ", ".join(all_input_names)
lines.append(f"def {fn_name}({params}):")
lines.append(" from diffusers.modular_pipelines.modular_pipeline import PipelineState")
lines.append("")
lines.append(" pipe, exec_blocks = _get_pipeline()")
lines.append(" state = PipelineState()")
for inp_name in all_input_names:
lines.append(f' state.set("{inp_name}", {inp_name})')
lines.append(f' block = exec_blocks.sub_blocks["{info.name}"]')
lines.append(" _, state = block(pipe, state)")
if len(info.outputs) == 0:
lines.append(" return None")
elif len(info.outputs) == 1:
out = info.outputs[0]
lines.append(f' return state.get("{out.name}")')
else:
out_names = [out.name for out in info.outputs]
out_dict = ", ".join(f'"{n}": state.get("{n}")' for n in out_names)
lines.append(f" return {{{out_dict}}}")
lines.append("")
lines.append("")
# Collect all user-facing inputs across blocks
all_user_inputs = OrderedDict()
for info in block_infos:
for inp in info.user_inputs:
if inp.name not in all_user_inputs:
all_user_inputs[inp.name] = inp
# InputNode
if all_user_inputs:
lines.append("# -- User Inputs --")
lines.append('user_inputs = InputNode("User Inputs", ports={')
for inp_name, inp in all_user_inputs.items():
gradio_comp = _type_hint_to_gradio(inp.type_hint, inp_name, inp.default)
if gradio_comp:
lines.append(f' "{inp_name}": {gradio_comp},')
lines.append("})")
lines.append("")
lines.append("")
# FnNode definitions
lines.append("# -- Pipeline Blocks --")
node_var_names = {}
for info in block_infos:
var_name = f"{_sanitize_name(info.name)}_node"
node_var_names[info.name] = var_name
fn_name = f"run_{_sanitize_name(info.name)}"
display_name = info.name.replace("_", " ").replace(".", " > ").title()
# Build inputs dict
input_entries = []
for inp in info.inputs:
if inp.name is None:
continue
connected = False
for conn_name, source_block in info.port_connections:
if conn_name == inp.name:
source_var = node_var_names[source_block]
input_entries.append(f' "{inp.name}": {source_var}.{inp.name},')
connected = True
break
if not connected:
if inp.name in all_user_inputs:
input_entries.append(f' "{inp.name}": user_inputs.{inp.name},')
elif inp.default is not None:
input_entries.append(f' "{inp.name}": {inp.default!r},')
else:
input_entries.append(f' "{inp.name}": None,')
# Build outputs dict
output_entries = []
for out in info.outputs:
gradio_out = _output_type_to_gradio(out.type_hint, out.name)
if gradio_out:
output_entries.append(f' "{out.name}": {gradio_out},')
else:
output_entries.append(f' "{out.name}": None,')
lines.append(f"{var_name} = FnNode(")
lines.append(f" fn={fn_name},")
lines.append(f' name="{display_name}",')
if input_entries:
lines.append(" inputs={")
lines.extend(input_entries)
lines.append(" },")
if output_entries:
lines.append(" outputs={")
lines.extend(output_entries)
lines.append(" },")
lines.append(")")
lines.append("")
# Graph
lines.append("")
lines.append("# -- Graph --")
all_node_vars = []
if all_user_inputs:
all_node_vars.append("user_inputs")
all_node_vars.extend(node_var_names[info.name] for info in block_infos)
graph_name = f"{blocks_class_name} - {workflow_label}"
nodes_str = ", ".join(all_node_vars)
lines.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
lines.append("graph.launch()")
lines.append("")
return "\n".join(lines)

View File

@@ -16,6 +16,7 @@
from argparse import ArgumentParser
from .custom_blocks import CustomBlocksCommand
from .daggr_app import DaggrCommand
from .env import EnvironmentCommand
from .fp16_safetensors import FP16SafetensorsCommand
@@ -28,6 +29,7 @@ def main():
EnvironmentCommand.register_subcommand(commands_parser)
FP16SafetensorsCommand.register_subcommand(commands_parser)
CustomBlocksCommand.register_subcommand(commands_parser)
DaggrCommand.register_subcommand(commands_parser)
# Let's go
args = parser.parse_args()

View File

@@ -60,16 +60,6 @@ class ContextParallelConfig:
rotate_method (`str`, *optional*, defaults to `"allgather"`):
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
is supported.
ulysses_anything (`bool`, *optional*, defaults to `False`):
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
`ring_degree` must be 1.
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
creating a new one. This is useful when combining context parallelism with other parallelism strategies
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
"""
@@ -78,7 +68,6 @@ class ContextParallelConfig:
convert_to_fp32: bool = True
# TODO: support alltoall
rotate_method: Literal["allgather", "alltoall"] = "allgather"
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
# Whether to enable ulysses anything attention to support
# any sequence lengths and any head numbers.
ulysses_anything: bool = False
@@ -135,7 +124,7 @@ class ContextParallelConfig:
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
)
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
self._flattened_mesh = self._mesh._flatten()
self._ring_mesh = self._mesh["ring"]
self._ulysses_mesh = self._mesh["ulysses"]
self._ring_local_rank = self._ring_mesh.get_local_rank()

View File

@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
mesh = None
if config.context_parallel_config is not None:
cp_config = config.context_parallel_config
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
mesh = torch.distributed.device_mesh.init_device_mesh(
device_type=device_type,
mesh_shape=cp_config.mesh_shape,
mesh_dim_names=cp_config.mesh_dim_names,

View File

@@ -95,7 +95,6 @@ from .pag import (
StableDiffusionXLPAGPipeline,
)
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,
@@ -186,7 +185,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
("z-image-omni", ZImageOmniPipeline),
("ovis", OvisImagePipeline),
("prx", PRXPipeline),
]
)

View File

@@ -82,16 +82,13 @@ EXAMPLE_DOC_STRING = """
```python
>>> import cv2
>>> import numpy as np
>>> from PIL import Image
>>> import torch
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
>>> from diffusers.utils import export_to_video, load_video
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
>>> controlnet = AutoModel.from_pretrained(
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
... )
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
... )

View File

@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
model.eval()
# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
inputs_on_device = {}
for key, value in inputs_dict.items():
if isinstance(value, torch.Tensor):
inputs_on_device[key] = value.to(device)
else:
inputs_on_device[key] = value
# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()
def _custom_mesh_worker(
rank,
world_size,
master_port,
model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
):
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
try:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
model = model_class(**init_dict)
model.to(device)
model.eval()
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
# DeviceMesh must be created after init_process_group, inside each worker process.
mesh = torch.distributed.device_mesh.init_device_mesh(
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
)
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
model.enable_parallelism(config=cp_config)
with torch.no_grad():
output = model(**inputs_on_device, return_dict=False)[0]
if rank == 0:
return_dict["status"] = "success"
return_dict["output_shape"] = list(output.shape)
except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()
@is_context_parallel
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
assert return_dict.get("status") == "success", (
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
],
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
cp_dict = {cp_type: world_size}
master_port = _find_free_port()
manager = mp.Manager()
return_dict = manager.dict()
mp.spawn(
_custom_mesh_worker,
args=(
world_size,
master_port,
self.model_class,
init_dict,
cp_dict,
mesh_shape,
mesh_dim_names,
inputs_dict,
return_dict,
),
nprocs=world_size,
join=True,
)
assert return_dict.get("status") == "success", (
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)

View File

@@ -1,242 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -9,6 +9,11 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -19,6 +24,7 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -431,6 +437,117 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,18 +24,14 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import CaptureLogger, nightly, require_torch, slow
from ..testing_utils import nightly, require_torch, slow
class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -358,117 +354,6 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch