mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-14 12:37:59 +08:00
Compare commits
3 Commits
tests-load
...
modular-da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4517b9311a | ||
|
|
7e3e640b5a | ||
|
|
7b961f07e7 |
@@ -532,6 +532,8 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -675,8 +677,6 @@
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
|
||||
@@ -21,31 +21,29 @@
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
## Basic usage
|
||||
## Loading original format checkpoints
|
||||
|
||||
Checkpoints in their original format that have not been converted to the Diffusers-expected layout can be loaded with the `from_single_file` method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
|
||||
)
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
image=None,
|
||||
video=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=93,
|
||||
generator=torch.Generator().manual_seed(1),
|
||||
).frames[0]
|
||||
export_to_video(output, "text2world.mp4", fps=16)
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
@@ -44,7 +44,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cosmos](cosmos) | text2video, video2video |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
| [DDPM](ddpm) | unconditional image generation |
|
||||
|
||||
447
src/diffusers/commands/daggr_app.py
Normal file
447
src/diffusers/commands/daggr_app.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import re
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger("diffusers-cli/daggr")
|
||||
|
||||
# Bare type names that are never exposed as user-facing Gradio inputs.
INTERNAL_TYPE_NAMES = {
    "Tensor",
    "Generator",
}

# Fully qualified type names treated as internal; also matched by substring
# against str(type_hint) so parameterized hints like list[torch.Tensor] count.
INTERNAL_TYPE_FULL_NAMES = {
    "torch.Tensor",
    "torch.Generator",
    "torch.dtype",
}

# Curated slider ranges for well-known numeric pipeline parameters; any input
# whose name matches a key here is rendered as gr.Slider with these bounds.
SLIDER_PARAMS = {
    "height": {"minimum": 256, "maximum": 2048, "step": 64},
    "width": {"minimum": 256, "maximum": 2048, "step": 64},
    "num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
    "guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
    "strength": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
    "controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}
|
||||
|
||||
|
||||
@dataclass
class BlockInfo:
    """Metadata collected for one pipeline block of the resolved execution graph."""

    # Block name as registered in ``exec_blocks.sub_blocks``.
    name: str
    # Class name of the block implementation.
    class_name: str
    # Human-readable description (may be empty).
    description: str
    # All declared inputs of the block.
    inputs: list
    # Intermediate outputs produced by the block.
    outputs: list
    # Inputs exposed to the user as Gradio components (populated by _classify_inputs).
    user_inputs: list = field(default_factory=list)
    # (input_name, source_block_name) pairs wired from earlier blocks' outputs.
    port_connections: list = field(default_factory=list)
    # Internal-typed inputs that are not user-configurable.
    fixed_inputs: list = field(default_factory=list)
|
||||
|
||||
|
||||
def daggr_command_factory(args: Namespace):
    """Build a :class:`DaggrCommand` from parsed CLI arguments."""
    output_path = args.output or "daggr_app.py"
    workflow = getattr(args, "workflow", None)
    triggers = getattr(args, "trigger_inputs", None)
    return DaggrCommand(
        repo_id=args.repo_id,
        output=output_path,
        workflow=workflow,
        trigger_inputs=triggers,
    )
|
||||
|
||||
|
||||
class DaggrCommand(BaseDiffusersCLICommand):
    """`diffusers-cli daggr` command: generate a daggr app from a modular pipeline repo."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        # Wire the `daggr` subcommand and its arguments onto the top-level CLI parser.
        daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
        daggr_parser.add_argument(
            "repo_id",
            type=str,
            help="HuggingFace Hub repo ID containing a modular pipeline (with modular_model_index.json).",
        )
        daggr_parser.add_argument(
            "--output",
            type=str,
            default="daggr_app.py",
            help="Output file path for the generated daggr app. Default: daggr_app.py",
        )
        daggr_parser.add_argument(
            "--workflow",
            type=str,
            default=None,
            help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
        )
        daggr_parser.add_argument(
            "--trigger-inputs",
            nargs="*",
            default=None,
            help="Trigger input names for manual conditional resolution.",
        )
        daggr_parser.set_defaults(func=daggr_command_factory)

    def __init__(
        self,
        repo_id: str,
        output: str = "daggr_app.py",
        workflow: str | None = None,
        trigger_inputs: list | None = None,
    ):
        self.repo_id = repo_id
        self.output = output
        self.workflow = workflow
        self.trigger_inputs = trigger_inputs

    def run(self):
        """Load the blocks, resolve the execution workflow, generate the app and write it out."""
        # Local import — NOTE(review): presumably deferred to keep CLI startup light; confirm.
        from ..modular_pipelines.modular_pipeline import ModularPipelineBlocks

        logger.info(f"Loading blocks from {self.repo_id}...")
        blocks = ModularPipelineBlocks.from_pretrained(self.repo_id, trust_remote_code=True)
        blocks_class_name = blocks.__class__.__name__

        # Resolve which blocks actually execute: named workflow takes priority,
        # then explicit trigger inputs, otherwise the default resolution.
        if self.workflow:
            logger.info(f"Resolving workflow: {self.workflow}")
            exec_blocks = blocks.get_workflow(self.workflow)
        elif self.trigger_inputs:
            trigger_kwargs = {name: True for name in self.trigger_inputs}
            logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
            exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
        else:
            logger.info("Resolving default execution blocks...")
            exec_blocks = blocks.get_execution_blocks()

        block_infos = _analyze_blocks(exec_blocks)
        _classify_inputs(block_infos)

        workflow_label = self.workflow or "default"
        workflow_resolve_code = self._get_workflow_resolve_code()
        code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)

        # Sanity-check the generated source; on failure we warn but still write it
        # so the user can inspect the broken output.
        try:
            ast.parse(code)
        except SyntaxError as e:
            logger.warning(f"Generated code has syntax error: {e}")

        with open(self.output, "w") as f:
            f.write(code)

        logger.info(f"Daggr app written to {self.output}")
        print(f"Generated daggr app: {self.output}")
        print(f" Pipeline: {blocks_class_name}")
        print(f" Workflow: {workflow_label}")
        print(f" Blocks: {len(block_infos)}")
        print(f"\nRun with: python {self.output}")

    def _get_workflow_resolve_code(self):
        """Return the Python expression the generated app embeds to re-resolve its blocks.

        Mirrors the resolution branch taken in :meth:`run` so the generated app
        reproduces the same block selection at runtime.
        """
        if self.workflow:
            return f"_pipeline._blocks.get_workflow({self.workflow!r})"
        elif self.trigger_inputs:
            kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
            return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
        else:
            return "_pipeline._blocks.get_execution_blocks()"
|
||||
|
||||
|
||||
def _analyze_blocks(exec_blocks):
    """Collect a :class:`BlockInfo` record for every sub-block of the resolved blocks."""
    collected = []
    for block_name, sub_block in exec_blocks.sub_blocks.items():
        declared_inputs = list(sub_block.inputs) if hasattr(sub_block, "inputs") else []
        has_outputs = hasattr(sub_block, "intermediate_outputs")
        declared_outputs = list(sub_block.intermediate_outputs) if has_outputs else []
        record = BlockInfo(
            name=block_name,
            class_name=sub_block.__class__.__name__,
            description=getattr(sub_block, "description", "") or "",
            inputs=declared_inputs,
            outputs=declared_outputs,
        )
        collected.append(record)
    return collected
|
||||
|
||||
|
||||
def _get_type_name(type_hint):
|
||||
if type_hint is None:
|
||||
return None
|
||||
if hasattr(type_hint, "__name__"):
|
||||
return type_hint.__name__
|
||||
if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
|
||||
return f"{type_hint.__module__}.{type_hint.__qualname__}"
|
||||
return str(type_hint)
|
||||
|
||||
|
||||
def _is_internal_type(type_hint):
    """Return ``True`` when *type_hint* should not be surfaced as a user-facing input.

    Internal types are: missing/unresolvable hints, known torch types
    (``INTERNAL_TYPE_NAMES`` / ``INTERNAL_TYPE_FULL_NAMES``, matched exactly or
    by substring so parameterized hints like ``list[torch.Tensor]`` count), and
    any dict-typed hint.
    """
    if type_hint is None:
        return True
    type_name = _get_type_name(type_hint)
    if type_name is None:
        return True
    if type_name in INTERNAL_TYPE_NAMES or type_name in INTERNAL_TYPE_FULL_NAMES:
        return True
    type_str = str(type_hint)
    for full_name in INTERNAL_TYPE_FULL_NAMES:
        if full_name in type_str:
            return True
    # Treat every dict-typed hint as internal. The previous check only caught
    # the builtin spellings ("dict" / "dict[...]"); typing.Dict[...] stringifies
    # as "typing.Dict[...]" and would otherwise fall through to a Textbox.
    if type_str in ("dict", "typing.Dict") or type_str.startswith(("dict[", "typing.Dict[")):
        return True
    return False
|
||||
|
||||
|
||||
def _type_hint_to_gradio(type_hint, param_name, default=None):
    """Map an input's type hint to a Gradio component constructor string.

    Returns ``None`` for internal types (not user-facing). Parameters listed in
    ``SLIDER_PARAMS`` get a slider with curated bounds; otherwise the component
    is chosen from the hint (str/int/float/bool/Image), falling back to a
    plain ``gr.Textbox``.
    """
    if _is_internal_type(type_hint):
        return None

    # Curated sliders for well-known numeric parameters (height, width, ...).
    if param_name in SLIDER_PARAMS:
        slider_opts = SLIDER_PARAMS[param_name]
        # Seed the slider at the declared default, else at the range minimum.
        val = default if default is not None else slider_opts.get("minimum", 0)
        return (
            f'gr.Slider(label="{param_name}", value={val!r}, '
            f"minimum={slider_opts['minimum']}, maximum={slider_opts['maximum']}, "
            f"step={slider_opts['step']})"
        )

    type_name = _get_type_name(type_hint)
    type_str = str(type_hint)

    if type_name == "str" or type_hint is str:
        # Prompt-like parameters get a taller textbox.
        lines = 3 if "prompt" in param_name else 1
        default_repr = f", value={default!r}" if default is not None else ""
        return f'gr.Textbox(label="{param_name}", lines={lines}{default_repr})'

    if type_name == "int" or type_hint is int:
        # precision=0 restricts the gr.Number to integers.
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}", precision=0{val})'

    if type_name == "float" or type_hint is float:
        val = f", value={default!r}" if default is not None else ""
        return f'gr.Number(label="{param_name}"{val})'

    if type_name == "bool" or type_hint is bool:
        val = default if default is not None else False
        return f'gr.Checkbox(label="{param_name}", value={val!r})'

    if "Image" in type_str:
        # A list of images becomes a gallery; a single image an image widget.
        if "list" in type_str.lower():
            return f'gr.Gallery(label="{param_name}")'
        return f'gr.Image(label="{param_name}")'

    # Unknown user-facing type: fall back to a free-form textbox.
    if default is not None:
        return f'gr.Textbox(label="{param_name}", value={default!r})'

    return f'gr.Textbox(label="{param_name}")'
|
||||
|
||||
|
||||
def _output_type_to_gradio(type_hint, param_name):
    """Map an output's type hint to a Gradio component string, or ``None`` if not displayable."""
    if _is_internal_type(type_hint):
        return None

    hint_repr = str(type_hint)
    if "Image" in hint_repr:
        is_collection = "list" in hint_repr.lower()
        component = "gr.Gallery" if is_collection else "gr.Image"
        return f'{component}(label="{param_name}")'

    if type_hint is str:
        return f'gr.Textbox(label="{param_name}")'

    numeric = type_hint is int or type_hint is float
    if numeric:
        return f'gr.Number(label="{param_name}")'

    return None
|
||||
|
||||
|
||||
def _classify_inputs(block_infos):
    """Partition each block's inputs into user inputs, port connections and fixed inputs.

    Walks the blocks in execution order. An input whose name matches an output
    produced by an earlier block becomes a port connection to that block;
    internal-typed inputs become fixed; everything else is user-facing.
    Results are stored on each ``BlockInfo`` in place.
    """
    producer_by_output = {}

    for info in block_infos:
        from_user = []
        wired = []
        fixed = []

        for inp in info.inputs:
            if inp.name is None:
                continue
            if inp.name in producer_by_output:
                wired.append((inp.name, producer_by_output[inp.name]))
            elif _is_internal_type(inp.type_hint):
                fixed.append(inp)
            else:
                from_user.append(inp)

        info.user_inputs = from_user
        info.port_connections = wired
        info.fixed_inputs = fixed

        # Register this block's outputs; the first producer of a name wins.
        for out in info.outputs:
            if out.name and out.name not in producer_by_output:
                producer_by_output[out.name] = info.name
|
||||
|
||||
|
||||
def _sanitize_name(name):
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
if sanitized and sanitized[0].isdigit():
|
||||
sanitized = f"_{sanitized}"
|
||||
return sanitized
|
||||
|
||||
|
||||
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
    """Assemble the source of the generated daggr app as one string.

    The generated module lazily loads the modular pipeline, defines one wrapper
    function and one ``FnNode`` per block, wires each block input to a user
    input / an earlier block's output / a literal default, and finally builds
    and launches a daggr ``Graph``.
    """
    lines = []

    # Module docstring and imports of the generated app.
    lines.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
    lines.append("Generated by: diffusers-cli daggr")
    lines.append('"""')
    lines.append("")
    lines.append("import gradio as gr")
    lines.append("from daggr import FnNode, InputNode, Graph")
    lines.append("")
    lines.append("")

    # Pipeline and resolved blocks loader
    lines.append("_pipeline = None")
    lines.append("_exec_blocks = None")
    lines.append("")
    lines.append("")
    lines.append("def _get_pipeline():")
    lines.append("    global _pipeline, _exec_blocks")
    lines.append("    if _pipeline is None:")
    lines.append("        from diffusers import ModularPipeline")
    lines.append(f"        _pipeline = ModularPipeline.from_pretrained({repo_id!r}, trust_remote_code=True)")
    lines.append("        _pipeline.load_components()")
    # workflow_resolve_code mirrors the CLI-time resolution (see _get_workflow_resolve_code).
    lines.append(f"        _exec_blocks = {workflow_resolve_code}")
    lines.append("    return _pipeline, _exec_blocks")
    lines.append("")
    lines.append("")

    # Wrapper functions: one per block; each runs the block against a fresh PipelineState.
    for info in block_infos:
        fn_name = f"run_{_sanitize_name(info.name)}"
        all_input_names = []
        for inp in info.inputs:
            if inp.name is not None:
                all_input_names.append(inp.name)

        params = ", ".join(all_input_names)
        lines.append(f"def {fn_name}({params}):")
        lines.append("    from diffusers.modular_pipelines.modular_pipeline import PipelineState")
        lines.append("")
        lines.append("    pipe, exec_blocks = _get_pipeline()")
        lines.append("    state = PipelineState()")
        for inp_name in all_input_names:
            lines.append(f'    state.set("{inp_name}", {inp_name})')
        lines.append(f'    block = exec_blocks.sub_blocks["{info.name}"]')
        lines.append("    _, state = block(pipe, state)")

        # Return shape depends on the block's output arity:
        # none -> None, one -> the value, many -> a name->value dict.
        if len(info.outputs) == 0:
            lines.append("    return None")
        elif len(info.outputs) == 1:
            out = info.outputs[0]
            lines.append(f'    return state.get("{out.name}")')
        else:
            out_names = [out.name for out in info.outputs]
            out_dict = ", ".join(f'"{n}": state.get("{n}")' for n in out_names)
            lines.append(f"    return {{{out_dict}}}")
        lines.append("")
        lines.append("")

    # Collect all user-facing inputs across blocks (first occurrence wins).
    all_user_inputs = OrderedDict()
    for info in block_infos:
        for inp in info.user_inputs:
            if inp.name not in all_user_inputs:
                all_user_inputs[inp.name] = inp

    # InputNode
    if all_user_inputs:
        lines.append("# -- User Inputs --")
        lines.append('user_inputs = InputNode("User Inputs", ports={')
        for inp_name, inp in all_user_inputs.items():
            gradio_comp = _type_hint_to_gradio(inp.type_hint, inp_name, inp.default)
            if gradio_comp:
                lines.append(f'    "{inp_name}": {gradio_comp},')
        lines.append("})")
        lines.append("")
        lines.append("")

    # FnNode definitions
    lines.append("# -- Pipeline Blocks --")
    node_var_names = {}

    for info in block_infos:
        var_name = f"{_sanitize_name(info.name)}_node"
        node_var_names[info.name] = var_name
        fn_name = f"run_{_sanitize_name(info.name)}"

        # e.g. "text_encoder.encode" -> "Text Encoder > Encode"
        display_name = info.name.replace("_", " ").replace(".", " > ").title()

        # Build inputs dict: prefer a wired connection from an earlier block,
        # then a user input, then a literal default, else None.
        input_entries = []
        for inp in info.inputs:
            if inp.name is None:
                continue

            connected = False
            for conn_name, source_block in info.port_connections:
                if conn_name == inp.name:
                    source_var = node_var_names[source_block]
                    input_entries.append(f'        "{inp.name}": {source_var}.{inp.name},')
                    connected = True
                    break

            if not connected:
                if inp.name in all_user_inputs:
                    input_entries.append(f'        "{inp.name}": user_inputs.{inp.name},')
                elif inp.default is not None:
                    input_entries.append(f'        "{inp.name}": {inp.default!r},')
                else:
                    input_entries.append(f'        "{inp.name}": None,')

        # Build outputs dict
        output_entries = []
        for out in info.outputs:
            gradio_out = _output_type_to_gradio(out.type_hint, out.name)
            if gradio_out:
                output_entries.append(f'        "{out.name}": {gradio_out},')
            else:
                output_entries.append(f'        "{out.name}": None,')

        lines.append(f"{var_name} = FnNode(")
        lines.append(f"    fn={fn_name},")
        lines.append(f'    name="{display_name}",')

        if input_entries:
            lines.append("    inputs={")
            lines.extend(input_entries)
            lines.append("    },")

        if output_entries:
            lines.append("    outputs={")
            lines.extend(output_entries)
            lines.append("    },")

        lines.append(")")
        lines.append("")

    # Graph
    lines.append("")
    lines.append("# -- Graph --")
    all_node_vars = []
    if all_user_inputs:
        all_node_vars.append("user_inputs")
    all_node_vars.extend(node_var_names[info.name] for info in block_infos)

    graph_name = f"{blocks_class_name} - {workflow_label}"
    nodes_str = ", ".join(all_node_vars)
    lines.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
    lines.append("graph.launch()")
    lines.append("")

    return "\n".join(lines)
|
||||
@@ -16,6 +16,7 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from .custom_blocks import CustomBlocksCommand
|
||||
from .daggr_app import DaggrCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .fp16_safetensors import FP16SafetensorsCommand
|
||||
|
||||
@@ -28,6 +29,7 @@ def main():
|
||||
EnvironmentCommand.register_subcommand(commands_parser)
|
||||
FP16SafetensorsCommand.register_subcommand(commands_parser)
|
||||
CustomBlocksCommand.register_subcommand(commands_parser)
|
||||
DaggrCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -60,16 +60,6 @@ class ContextParallelConfig:
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
ulysses_anything (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
|
||||
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
|
||||
`ring_degree` must be 1.
|
||||
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
|
||||
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
|
||||
creating a new one. This is useful when combining context parallelism with other parallelism strategies
|
||||
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
|
||||
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
|
||||
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
|
||||
|
||||
"""
|
||||
|
||||
@@ -78,7 +68,6 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
@@ -135,7 +124,7 @@ class ContextParallelConfig:
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
|
||||
@@ -95,7 +95,6 @@ from .pag import (
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .prx import PRXPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
@@ -186,7 +185,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
("prx", PRXPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -82,16 +82,13 @@ EXAMPLE_DOC_STRING = """
|
||||
```python
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
|
||||
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
|
||||
>>> controlnet = AutoModel.from_pretrained(
|
||||
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
|
||||
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
|
||||
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -60,7 +60,12 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
@@ -84,59 +89,6 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
|
||||
rank,
|
||||
world_size,
|
||||
master_port,
|
||||
model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
):
|
||||
"""Worker function for context parallel testing with a user-provided custom DeviceMesh."""
|
||||
try:
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = str(master_port)
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# DeviceMesh must be created after init_process_group, inside each worker process.
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_on_device, return_dict=False)[0]
|
||||
|
||||
if rank == 0:
|
||||
return_dict["status"] = "success"
|
||||
return_dict["output_shape"] = list(output.shape)
|
||||
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
return_dict["status"] = "error"
|
||||
return_dict["error"] = str(e)
|
||||
finally:
|
||||
if dist.is_initialized():
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@@ -174,48 +126,3 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"cp_type,mesh_shape,mesh_dim_names",
|
||||
[
|
||||
("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
|
||||
("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
|
||||
],
|
||||
ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
|
||||
)
|
||||
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
|
||||
if not torch.distributed.is_available():
|
||||
pytest.skip("torch.distributed is not available.")
|
||||
|
||||
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
|
||||
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
|
||||
|
||||
world_size = 2
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}
|
||||
cp_dict = {cp_type: world_size}
|
||||
|
||||
master_port = _find_free_port()
|
||||
manager = mp.Manager()
|
||||
return_dict = manager.dict()
|
||||
|
||||
mp.spawn(
|
||||
_custom_mesh_worker,
|
||||
args=(
|
||||
world_size,
|
||||
master_port,
|
||||
self.model_class,
|
||||
init_dict,
|
||||
cp_dict,
|
||||
mesh_shape,
|
||||
mesh_dim_names,
|
||||
inputs_dict,
|
||||
return_dict,
|
||||
),
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
|
||||
@@ -33,33 +32,6 @@ from ..testing_utils import (
|
||||
)
|
||||
|
||||
|
||||
def _get_specified_components(path_or_repo_id, cache_dir=None):
|
||||
if os.path.isdir(path_or_repo_id):
|
||||
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
|
||||
else:
|
||||
try:
|
||||
config_path = hf_hub_download(
|
||||
repo_id=path_or_repo_id,
|
||||
filename="modular_model_index.json",
|
||||
local_dir=cache_dir,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
components = set()
|
||||
for k, v in config.items():
|
||||
if isinstance(v, (str, int, float, bool)):
|
||||
continue
|
||||
for entry in v:
|
||||
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
|
||||
components.add(k)
|
||||
break
|
||||
return components
|
||||
|
||||
|
||||
class ModularPipelineTesterMixin:
|
||||
"""
|
||||
It provides a set of common tests for each modular pipeline,
|
||||
@@ -388,39 +360,6 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_load_expected_components_from_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
|
||||
if not expected:
|
||||
pytest.skip("Skipping test as we couldn't fetch the expected components.")
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in pipe.components
|
||||
if getattr(pipe, name, None) is not None
|
||||
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
|
||||
|
||||
def test_load_expected_components_from_save_pretrained(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
save_dir = str(tmp_path / "saved-pipeline")
|
||||
pipe.save_pretrained(save_dir)
|
||||
|
||||
expected = _get_specified_components(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
loaded_pipe.load_components(torch_dtype=torch.float32)
|
||||
|
||||
actual = {
|
||||
name
|
||||
for name in loaded_pipe.components
|
||||
if getattr(loaded_pipe, name, None) is not None
|
||||
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
|
||||
}
|
||||
assert expected == actual, (
|
||||
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
|
||||
)
|
||||
|
||||
def test_modular_index_consistency(self, tmp_path):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
|
||||
Reference in New Issue
Block a user