mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-10 01:37:08 +08:00
Compare commits
6 Commits
hunyuan-te
...
modular-da
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d75d0b6773 | ||
|
|
b142d63fa8 | ||
|
|
46ec9434a4 | ||
|
|
4517b9311a | ||
|
|
7e3e640b5a | ||
|
|
7b961f07e7 |
@@ -2,6 +2,7 @@
|
||||
line-length = 119
|
||||
extend-exclude = [
|
||||
"src/diffusers/pipelines/flux2/system_messages.py",
|
||||
"src/diffusers/commands/daggr_app.py",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
726
src/diffusers/commands/daggr_app.py
Normal file
726
src/diffusers/commands/daggr_app.py
Normal file
@@ -0,0 +1,726 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import re
|
||||
import tempfile
|
||||
import types
|
||||
import typing
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger("diffusers-cli/daggr")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Bare type names treated as pipeline-internal state: values of these types
# are piped between blocks via the tensor store, never shown as Gradio inputs.
INTERNAL_TYPE_NAMES = {
    "Tensor",
    "Generator",
}

# Fully-qualified names treated the same way; these are also matched as
# substrings of the rendered type hint (see _is_internal_type).
INTERNAL_TYPE_FULL_NAMES = {
    "torch.Tensor",
    "torch.Generator",
    "torch.dtype",
}

# Well-known numeric parameters rendered as gr.Slider with these ranges.
SLIDER_PARAMS = {
    "height": {"minimum": 256, "maximum": 2048, "step": 64},
    "width": {"minimum": 256, "maximum": 2048, "step": 64},
    "num_inference_steps": {"minimum": 1, "maximum": 100, "step": 1},
    "guidance_scale": {"minimum": 0, "maximum": 30, "step": 0.5},
    "strength": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_start": {"minimum": 0, "maximum": 1, "step": 0.05},
    "control_guidance_end": {"minimum": 0, "maximum": 1, "step": 0.05},
    "controlnet_conditioning_scale": {"minimum": 0, "maximum": 2, "step": 0.1},
}

# Parameters with a closed set of values, rendered as gr.Dropdown.
DROPDOWN_PARAMS = {
    "output_type": {"choices": ["pil", "np", "pt", "latent"], "value": "pil"},
}

# requirements.txt contents written for a deployed Space when the user does
# not pass --requirements.
DEFAULT_REQUIREMENTS = (
    "torch --index-url https://download.pytorch.org/whl/cu129\n"
    "diffusers\n"
    "transformers\n"
    "accelerate\n"
    "sentencepiece\n"
    "bitsandbytes\n"
    "daggr\n"
    "gradio\n"
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code templates (string concat so formatters can't mangle newlines)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Emitted at the top of every generated app. Defines a lazy, cached pipeline
# loader so the slow model download/instantiation happens once, on first use.
# Doubled braces ({{ }}) survive .format() as literal braces in the output.
_HEADER_TEMPLATE = (
    "import os\n"
    "{extra_imports}\n"
    "import gradio as gr\n"
    "from daggr import FnNode, InputNode, Graph\n"
    "\n"
    "\n"
    "_pipeline = None\n"
    "_exec_blocks = None\n"
    "_tensor_store = {{}}\n"
    "\n"
    "\n"
    "def _get_pipeline():\n"
    "    global _pipeline, _exec_blocks\n"
    "    if _pipeline is None:\n"
    "        from diffusers import ModularPipeline\n"
    "        import torch\n"
    "\n"
    '        _token = os.environ.get("HF_TOKEN")\n'
    '        _device = "cuda" if torch.cuda.is_available() else "cpu"\n'
    "        _pipeline = ModularPipeline.from_pretrained({repo_id}, trust_remote_code=True, token=_token)\n"
    "        _pipeline.load_components(torch_dtype=torch.bfloat16, device_map=_device)\n"
    "        _exec_blocks = {workflow_resolve_code}\n"
    "    return _pipeline, _exec_blocks"
)

# Helper emitted only when the graph handles images: saves a PIL image (or the
# first image of a list) to a temp PNG so Gradio can display it by file path.
# Non-image values pass through unchanged.
_SAVE_IMAGE_TEMPLATE = (
    "\n"
    "\n"
    "def _save_image(val):\n"
    "    from PIL import Image as PILImage\n"
    "\n"
    "    if isinstance(val, list):\n"
    "        paths = [_save_image(item) for item in val]\n"
    "        return paths[0] if paths else None\n"
    "    if isinstance(val, PILImage.Image):\n"
    '        f = tempfile.NamedTemporaryFile(suffix=".png", delete=False)\n'
    "        val.save(f.name)\n"
    "        return f.name\n"
    "    return val"
)

# Monkey-patches daggr's session cache into no-ops so every click re-runs the
# whole pipeline instead of replaying cached node results.
_CACHE_DISABLE_TEMPLATE = (
    "\n"
    "# Disable result caching so every run executes all nodes fresh\n"
    "from daggr.state import SessionState\n"
    "SessionState.get_latest_result = lambda *a, **kw: None\n"
    "SessionState.get_result_by_index = lambda *a, **kw: None\n"
    "SessionState.save_result = lambda *a, **kw: None\n"
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class BlockInfo:
    """Analysis record for one executable pipeline sub-block.

    The first five fields are filled by ``_get_block_info``; the remaining
    lists are populated afterwards by ``_classify_inputs``.
    """

    # Key of the sub-block inside exec_blocks.sub_blocks.
    name: str
    # Class name of the block implementation (used for labels only).
    class_name: str
    # The block's description attribute, or "" when absent.
    description: str
    # Input params; items expose .name, .type_hint and .default.
    inputs: list
    # Intermediate outputs; items expose .name and .type_hint.
    outputs: list
    # Inputs that become visible Gradio controls.
    user_inputs: list = field(default_factory=list)
    # (input_name, producing_block_name) pairs wired from earlier blocks.
    port_connections: list = field(default_factory=list)
    # Inputs that are neither user-facing nor wired; kept at their defaults.
    fixed_inputs: list = field(default_factory=list)
    # Optional explicit list of sub-block keys to execute instead of `name`
    # (consumed by the generated run_* function body).
    sub_block_names: list = field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI command
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def daggr_command_factory(args: Namespace):
    """Build a ``DaggrCommand`` from parsed CLI arguments.

    Optional flags are read defensively with per-flag fallbacks so the
    factory also works for namespaces missing those attributes.
    """
    optional_defaults = {
        "workflow": None,
        "trigger_inputs": None,
        "deploy": None,
        "hardware": "cpu-basic",
        "private": False,
        "requirements": None,
    }
    optional_kwargs = {key: getattr(args, key, fallback) for key, fallback in optional_defaults.items()}
    return DaggrCommand(repo_id=args.repo_id, output=args.output, **optional_kwargs)
|
||||
|
||||
|
||||
class DaggrCommand(BaseDiffusersCLICommand):
    """Implementation of ``diffusers-cli daggr``.

    Loads a modular pipeline from the Hub, resolves its execution blocks,
    and renders them into a standalone daggr/Gradio app that is deployed to
    a Space, saved to a file, or launched in-process.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Attach the ``daggr`` subcommand and its options to the CLI parser."""
        daggr_parser = parser.add_parser("daggr", help="Generate a daggr app from a modular pipeline repo.")
        daggr_parser.add_argument(
            "repo_id",
            type=str,
            help="HuggingFace Hub repo ID containing a modular pipeline.",
        )
        daggr_parser.add_argument(
            "--output",
            type=str,
            default=None,
            help="Save the generated app to a file instead of launching it directly.",
        )
        daggr_parser.add_argument(
            "--workflow",
            type=str,
            default=None,
            help="Named workflow to resolve conditional blocks (e.g. 'text2image', 'image2image').",
        )
        daggr_parser.add_argument(
            "--trigger-inputs",
            nargs="*",
            default=None,
            help="Trigger input names for manual conditional resolution.",
        )
        daggr_parser.add_argument(
            "--deploy",
            type=str,
            default=None,
            metavar="SPACE_NAME",
            help="Deploy the generated app to a HuggingFace Space via daggr deploy.",
        )
        daggr_parser.add_argument(
            "--hardware",
            type=str,
            default="cpu-basic",
            help="Hardware tier for the deployed Space (default: cpu-basic). E.g. a10g-small, a100-large.",
        )
        daggr_parser.add_argument(
            "--private",
            action="store_true",
            default=False,
            help="Make the deployed Space private.",
        )
        daggr_parser.add_argument(
            "--requirements",
            type=str,
            default=None,
            help="Path to a requirements.txt file for the deployed Space.",
        )
        daggr_parser.set_defaults(func=daggr_command_factory)

    def __init__(
        self,
        repo_id: str,
        output: str | None = None,
        workflow: str | None = None,
        trigger_inputs: list | None = None,
        deploy: str | None = None,
        hardware: str = "cpu-basic",
        private: bool = False,
        requirements: str | None = None,
    ):
        self.repo_id = repo_id
        self.output = output
        self.workflow = workflow
        self.trigger_inputs = trigger_inputs
        self.deploy = deploy
        self.hardware = hardware
        self.private = private
        self.requirements = requirements

    def run(self):
        """Generate the app code, then deploy, write, or launch it.

        Destination precedence: ``--deploy`` > ``--output`` > launch locally.
        """
        from ..modular_pipelines.modular_pipeline import ModularPipeline

        logger.info(f"Loading blocks from {self.repo_id}...")
        pipeline = ModularPipeline.from_pretrained(self.repo_id, trust_remote_code=True)
        blocks = pipeline._blocks
        blocks_class_name = blocks.__class__.__name__

        # Conditional blocks are resolved by (in priority order): a named
        # workflow, explicit trigger inputs, or the pipeline's default path.
        if self.workflow:
            logger.info(f"Resolving workflow: {self.workflow}")
            exec_blocks = blocks.get_workflow(self.workflow)
        elif self.trigger_inputs:
            trigger_kwargs = {name: True for name in self.trigger_inputs}
            logger.info(f"Resolving with trigger inputs: {self.trigger_inputs}")
            exec_blocks = blocks.get_execution_blocks(**trigger_kwargs)
        else:
            logger.info("Resolving default execution blocks...")
            exec_blocks = blocks.get_execution_blocks()

        block_infos = _get_block_info(exec_blocks)
        _filter_outputs(block_infos)
        _classify_inputs(block_infos)

        workflow_label = self.workflow or "default"
        workflow_resolve_code = self._get_workflow_resolve_code()
        code = _generate_code(block_infos, self.repo_id, blocks_class_name, workflow_label, workflow_resolve_code)

        # Sanity-check the generated source; warn instead of failing so the
        # user still receives the output for inspection.
        try:
            ast.parse(code)
        except SyntaxError as e:
            logger.warning(f"Generated code has syntax error: {e}")

        if self.deploy:
            self._deploy_to_space(code)
        elif self.output:
            with open(self.output, "w") as f:
                f.write(code)
            print(f"Generated daggr app: {self.output}")
            print(f" Pipeline: {blocks_class_name}")
            print(f" Workflow: {workflow_label}")
            print(f" Blocks: {len(block_infos)}")
            print(f"\nRun with: python {self.output}")
        else:
            print(f"Launching daggr app for {blocks_class_name} ({workflow_label} workflow)...")
            tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
            tmp.write(code)
            tmp.close()
            logger.info(f"Generated temp script: {tmp.name}")
            # Run the app in-process; compiling against tmp.name makes
            # tracebacks point at the generated file on disk.
            exec(compile(code, tmp.name, "exec"), {"__name__": "__main__"})

    def _deploy_to_space(self, code):
        """Write the generated app and a requirements file to temp files and
        hand them to daggr's deploy helper to create/update a Space."""
        import os
        from pathlib import Path

        from daggr.cli import _deploy

        tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False, prefix="daggr_")
        tmp.write(code)
        tmp.close()

        # Fall back to the bundled default requirements when the user did
        # not pass --requirements.
        req_path = self.requirements
        if not req_path:
            req_file = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False, prefix="daggr_req_")
            req_file.write(DEFAULT_REQUIREMENTS)
            req_file.close()
            req_path = req_file.name

        # Forward the local HF token (if set) so the Space can pull gated models.
        secrets = {}
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            secrets["HF_TOKEN"] = hf_token

        # "org/name" targets an organization; a bare name deploys to the user.
        deploy_name = self.deploy
        deploy_org = None
        if "/" in deploy_name:
            deploy_org, deploy_name = deploy_name.rsplit("/", 1)

        _deploy(
            script_path=Path(tmp.name),
            name=deploy_name,
            title=None,
            org=deploy_org,
            private=self.private,
            hardware=self.hardware,
            secrets=secrets,
            requirements_path=req_path,
            dry_run=False,
        )

    def _get_workflow_resolve_code(self):
        """Return the source expression the generated app evaluates at runtime
        to re-resolve the same execution blocks ``run()`` resolved here."""
        if self.workflow:
            return f"_pipeline._blocks.get_workflow({self.workflow!r})"
        elif self.trigger_inputs:
            kwargs_str = ", ".join(f"{name!r}: True" for name in self.trigger_inputs)
            return f"_pipeline._blocks.get_execution_blocks(**{{{kwargs_str}}})"
        else:
            return "_pipeline._blocks.get_execution_blocks()"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Block analysis
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_block_info(exec_blocks):
    """Collect a ``BlockInfo`` record for every sub-block of *exec_blocks*.

    Missing ``inputs`` / ``intermediate_outputs`` attributes degrade to
    empty lists rather than raising.
    """
    return [
        BlockInfo(
            name=block_name,
            class_name=sub_block.__class__.__name__,
            description=getattr(sub_block, "description", "") or "",
            inputs=list(sub_block.inputs) if hasattr(sub_block, "inputs") else [],
            outputs=list(sub_block.intermediate_outputs) if hasattr(sub_block, "intermediate_outputs") else [],
        )
        for block_name, sub_block in exec_blocks.sub_blocks.items()
    ]
|
||||
|
||||
|
||||
def _filter_outputs(block_infos):
    """Prune each block's outputs in place.

    An output is kept when a later block consumes it by name, or — for the
    final block only — when it is an image or a non-internal value (i.e.
    something worth showing to the user).

    The original implementation rebuilt the set of downstream input names
    from scratch for every block (O(n^2) over blocks); here the suffix sets
    are precomputed in a single reverse pass.
    """
    n = len(block_infos)
    # suffix_inputs[i] = names consumed by any block strictly after index i.
    suffix_inputs = [set() for _ in range(n)]
    seen = set()
    for i in range(n - 1, -1, -1):
        suffix_inputs[i] = set(seen)
        for inp in block_infos[i].inputs:
            if inp.name:
                seen.add(inp.name)

    for i, info in enumerate(block_infos):
        downstream_input_names = suffix_inputs[i]
        is_last = i == n - 1
        info.outputs = [
            out
            for out in info.outputs
            if out.name in downstream_input_names
            or (is_last and (_contains_pil_image(out.type_hint) or not _is_internal_type(out.type_hint)))
        ]
|
||||
|
||||
|
||||
def _classify_inputs(block_infos):
    """Partition each block's inputs into user-facing controls, port
    connections from earlier blocks, and fixed (hidden) inputs.

    Mutates every ``BlockInfo`` in place. Relies on block order: an input is
    a port connection when an earlier block produced an output of the same
    name (the most recent producer wins).
    """
    produced_by = {}  # output name -> name of the block that produced it

    for info in block_infos:
        user_facing, wired, fixed = [], [], []

        for inp in info.inputs:
            if inp.name is None:
                continue
            if inp.name in produced_by:
                wired.append((inp.name, produced_by[inp.name]))
            elif _is_user_facing(inp):
                user_facing.append(inp)
            else:
                fixed.append(inp)

        info.user_inputs = user_facing
        info.port_connections = wired
        info.fixed_inputs = fixed

        # Register this block's outputs for downstream wiring.
        for out in info.outputs:
            if out.name:
                produced_by[out.name] = info.name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_type_name(type_hint):
|
||||
if type_hint is None:
|
||||
return None
|
||||
if hasattr(type_hint, "__name__"):
|
||||
return type_hint.__name__
|
||||
if hasattr(type_hint, "__module__") and hasattr(type_hint, "__qualname__"):
|
||||
return f"{type_hint.__module__}.{type_hint.__qualname__}"
|
||||
return str(type_hint)
|
||||
|
||||
|
||||
def _is_internal_type(type_hint):
    """True when *type_hint* denotes pipeline-internal state (tensors,
    generators, dtypes, or dict-shaped values) that should not be exposed
    as a Gradio component.
    """
    if type_hint is None:
        return False
    name = _get_type_name(type_hint)
    if name is None:
        return False
    if name in INTERNAL_TYPE_NAMES or name in INTERNAL_TYPE_FULL_NAMES:
        return True
    # Also catch internal types nested inside generics by substring match
    # against the rendered hint (e.g. "list[torch.Tensor]").
    rendered = str(type_hint)
    if any(full_name in rendered for full_name in INTERNAL_TYPE_FULL_NAMES):
        return True
    return rendered == "dict" or rendered.startswith("dict[")
|
||||
|
||||
|
||||
def _contains_pil_image(type_hint):
    """True when *type_hint* is ``PIL.Image.Image`` or nests it anywhere
    inside a generic (e.g. ``list[Image]``, ``Optional[Image]``)."""
    from PIL import Image

    if type_hint is Image.Image:
        return True
    nested = typing.get_args(type_hint)
    if not nested:
        return False
    return any(_contains_pil_image(arg) for arg in nested)
|
||||
|
||||
|
||||
def _is_list_or_tuple_type(type_hint):
|
||||
origin = typing.get_origin(type_hint)
|
||||
if origin in (list, tuple):
|
||||
return True
|
||||
if origin is typing.Union or origin is types.UnionType:
|
||||
return any(_is_list_or_tuple_type(a) for a in typing.get_args(type_hint))
|
||||
return False
|
||||
|
||||
|
||||
def _resolve_from_template(inp):
    """Return ``(type_hint, default)`` for *inp*, filling missing values from
    the shared ``INPUT_PARAM_TEMPLATES`` table keyed by input name."""
    from ..modular_pipelines.modular_pipeline_utils import INPUT_PARAM_TEMPLATES

    type_hint, default = inp.type_hint, inp.default
    if inp.name in INPUT_PARAM_TEMPLATES:
        template = INPUT_PARAM_TEMPLATES[inp.name]
        type_hint = template.get("type_hint") if type_hint is None else type_hint
        default = template.get("default") if default is None else default
    return type_hint, default
|
||||
|
||||
|
||||
def _is_user_facing(inp):
    """Decide whether *inp* should become a visible Gradio control.

    Images are always user-facing; otherwise anything non-internal is.
    With no resolvable type hint, only known slider parameters qualify.
    """
    type_hint, _ = _resolve_from_template(inp)
    if type_hint is None:
        return inp.name in SLIDER_PARAMS
    return _contains_pil_image(type_hint) or not _is_internal_type(type_hint)
|
||||
|
||||
|
||||
def _sanitize_name(name):
|
||||
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
if sanitized and sanitized[0].isdigit():
|
||||
sanitized = f"_{sanitized}"
|
||||
return sanitized
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gradio component code strings
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _type_hint_to_gradio(type_hint, param_name, default=None):
    """Return source code for the Gradio component of one user input.

    Name-based tables (dropdowns, sliders) take precedence over the type
    hint. Returns ``None`` for internal-only types that must stay hidden.
    """
    if param_name in DROPDOWN_PARAMS:
        opts = DROPDOWN_PARAMS[param_name]
        value = opts.get("value") if default is None else default
        return f'gr.Dropdown(label="{param_name}", choices={opts["choices"]!r}, value={value!r})'

    if param_name in SLIDER_PARAMS:
        opts = SLIDER_PARAMS[param_name]
        value = opts.get("minimum", 0) if default is None else default
        return (
            f'gr.Slider(label="{param_name}", value={value!r}, '
            f"minimum={opts['minimum']}, maximum={opts['maximum']}, step={opts['step']})"
        )

    if type_hint is not None:
        if _is_internal_type(type_hint):
            return None
        if _contains_pil_image(type_hint):
            return f'gr.Image(label="{param_name}")'

    value_suffix = "" if default is None else f", value={default!r}"
    if type_hint is str:
        return f'gr.Textbox(label="{param_name}", lines=1{value_suffix})'
    if type_hint is int:
        return f'gr.Number(label="{param_name}", precision=0{value_suffix})'
    if type_hint is float:
        return f'gr.Number(label="{param_name}"{value_suffix})'
    if type_hint is bool:
        checked = False if default is None else default
        return f'gr.Checkbox(label="{param_name}", value={checked!r})'
    if default is not None:
        return f'gr.Textbox(label="{param_name}", value={default!r})'
    return f'gr.Textbox(label="{param_name}")'
|
||||
|
||||
|
||||
def _output_type_to_gradio(type_hint, param_name):
    """Return source code for the Gradio component showing one block output.

    Unrecognized types get an invisible Textbox so the port still exists
    for wiring without cluttering the UI.
    """
    if type_hint is not None and _contains_pil_image(type_hint):
        return f'gr.Image(label="{param_name}")'
    if type_hint is str:
        return f'gr.Textbox(label="{param_name}")'
    if type_hint in (int, float):
        return f'gr.Number(label="{param_name}")'
    return f'gr.Textbox(label="{param_name}", visible=False)'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_code(block_infos, repo_id, blocks_class_name, workflow_label, workflow_resolve_code):
    """Render the complete source of the daggr app for the resolved blocks.

    The result is a standalone script: module docstring, the lazy pipeline
    loader header, one ``run_*`` function per block, InputNode/FnNode
    definitions wiring blocks together, and a ``Graph(...).launch()`` footer
    with daggr's result cache disabled.
    """
    sections = []

    # Pre-compute metadata
    blocks_with_image_inputs = {}  # block name -> input names that take images
    blocks_with_parseable_inputs = {}  # block name -> inputs parsed from str via ast.literal_eval
    blocks_with_image_outputs = set()
    port_conn_source = {}  # (block name, input name) -> producing block name

    for info in block_infos:
        img_names, parse_names = set(), set()
        for inp in info.inputs:
            if inp.name is None:
                continue
            resolved_type, _ = _resolve_from_template(inp)
            if resolved_type is not None and _contains_pil_image(resolved_type):
                img_names.add(inp.name)
            elif resolved_type is not None and _is_list_or_tuple_type(resolved_type):
                parse_names.add(inp.name)
        if img_names:
            blocks_with_image_inputs[info.name] = img_names
        if parse_names:
            blocks_with_parseable_inputs[info.name] = parse_names
        if any(out.type_hint is not None and _contains_pil_image(out.type_hint) for out in info.outputs):
            blocks_with_image_outputs.add(info.name)
        for conn_name, source_block in info.port_connections:
            port_conn_source[(info.name, conn_name)] = source_block

    needs_image_io = blocks_with_image_inputs or blocks_with_image_outputs
    needs_parsing = bool(blocks_with_parseable_inputs)

    # -- Header & helpers -- only emit imports/helpers the app actually uses
    extra_imports = ""
    if needs_parsing:
        extra_imports += "import ast\n"
    if needs_image_io:
        extra_imports += "import tempfile\n"

    sections.append(f'"""Daggr app for {blocks_class_name} ({workflow_label} workflow)')
    sections.append("Generated by: diffusers-cli daggr")
    sections.append('"""')
    header = _HEADER_TEMPLATE.format(
        extra_imports=extra_imports, repo_id=repr(repo_id), workflow_resolve_code=workflow_resolve_code
    )
    sections.extend(header.splitlines())

    if needs_image_io:
        sections.extend(_SAVE_IMAGE_TEMPLATE.splitlines())

    # -- Block functions -- one run_<block>() wrapper per pipeline block
    for info in block_infos:
        fn_name = f"run_{_sanitize_name(info.name)}"
        input_names = [inp.name for inp in info.inputs if inp.name is not None]
        port_conn_names = {c for c, _ in info.port_connections}
        has_image_out = info.name in blocks_with_image_outputs

        body_lines = []
        body_lines.append("    from diffusers.modular_pipelines.modular_pipeline import PipelineState")
        body_lines.append("")
        body_lines.append("    pipe, exec_blocks = _get_pipeline()")
        body_lines.append("    state = PipelineState()")

        # Gradio may hand image inputs over as file paths; open them lazily.
        if info.name in blocks_with_image_inputs:
            body_lines.append("    from PIL import Image as PILImage")
            body_lines.append("")
            for img_name in blocks_with_image_inputs[info.name]:
                body_lines.append(f"    if {img_name} is not None and isinstance({img_name}, str):")
                body_lines.append(f"        {img_name} = PILImage.open({img_name})")

        # List/tuple-typed inputs arrive as text; parse them as Python literals.
        if info.name in blocks_with_parseable_inputs:
            for parse_name in blocks_with_parseable_inputs[info.name]:
                body_lines.append(f"    if {parse_name} is not None and isinstance({parse_name}, str):")
                body_lines.append(
                    f"        {parse_name} = ast.literal_eval({parse_name}.strip()) if {parse_name}.strip() else None"
                )

        # Wired inputs prefer the tensor stashed by the producing block; the
        # port value (a "<block>:<name>" key string) is only the fallback.
        for n in input_names:
            if n in port_conn_names:
                source = port_conn_source[(info.name, n)]
                body_lines.append(f'    state.set("{n}", _tensor_store.get("{source}:{n}", {n}))')
            else:
                body_lines.append(f'    state.set("{n}", {n})')

        if info.sub_block_names:
            for n in info.sub_block_names:
                body_lines.append(f'    _, state = exec_blocks.sub_blocks["{n}"](pipe, state)')
        else:
            body_lines.append(f'    _, state = exec_blocks.sub_blocks["{info.name}"](pipe, state)')

        # Internal outputs (tensors etc.) are stashed in _tensor_store and
        # represented on the wire by their "<block>:<name>" key.
        for o in info.outputs:
            if o.type_hint is not None and _is_internal_type(o.type_hint):
                body_lines.append(f'    _tensor_store["{info.name}:{o.name}"] = state.get("{o.name}")')

        if len(info.outputs) == 0:
            body_lines.append("    return None")
        elif len(info.outputs) == 1:
            out = info.outputs[0]
            if has_image_out and _contains_pil_image(out.type_hint):
                body_lines.append(f'    return _save_image(state.get("{out.name}"))')
            elif out.type_hint is not None and _is_internal_type(out.type_hint):
                body_lines.append(f'    return "{info.name}:{out.name}"')
            else:
                body_lines.append(f'    return state.get("{out.name}")')
        else:
            parts = []
            for o in info.outputs:
                if has_image_out and o.type_hint is not None and _contains_pil_image(o.type_hint):
                    parts.append(f'_save_image(state.get("{o.name}"))')
                elif o.type_hint is not None and _is_internal_type(o.type_hint):
                    parts.append(f'"{info.name}:{o.name}"')
                else:
                    parts.append(f'state.get("{o.name}")')
            body_lines.append(f"    return {', '.join(parts)}")

        body = "\n".join(body_lines)
        sections.append(f"\n\ndef {fn_name}({', '.join(input_names)}):\n{body}\n")

    # -- Node definitions --
    sections.append("\n\n# -- Pipeline Blocks --")

    node_var_names = {}  # block name -> generated FnNode variable name
    input_node_var_names = []
    user_input_sources = {}  # input name -> InputNode variable that feeds it

    for info in block_infos:
        var_name = f"{_sanitize_name(info.name)}_node"
        node_var_names[info.name] = var_name
        fn_name = f"run_{_sanitize_name(info.name)}"
        display_name = info.name.replace("_", " ").replace(".", " > ").title()

        # Emit one InputNode per distinct user-facing input (deduplicated
        # across blocks so shared inputs get a single control).
        for inp in info.user_inputs:
            if inp.name in user_input_sources:
                continue
            resolved_type, resolved_default = _resolve_from_template(inp)
            gradio_comp = _type_hint_to_gradio(resolved_type, inp.name, resolved_default)
            if gradio_comp:
                input_var = f"{_sanitize_name(inp.name)}_input"
                sections.append(
                    f'\n{input_var} = InputNode("{inp.name}", ports={{\n    "{inp.name}": {gradio_comp},\n}})'
                )
                input_node_var_names.append(input_var)
                user_input_sources[inp.name] = input_var

        # Wire each input: upstream node port > user InputNode > literal default > None.
        input_entries = []
        for inp in info.inputs:
            if inp.name is None:
                continue
            connected = False
            for conn_name, source_block in info.port_connections:
                if conn_name == inp.name:
                    input_entries.append(f'        "{inp.name}": {node_var_names[source_block]}.{inp.name},')
                    connected = True
                    break
            if not connected:
                if inp.name in user_input_sources:
                    input_entries.append(f'        "{inp.name}": {user_input_sources[inp.name]}.{inp.name},')
                elif inp.default is not None:
                    input_entries.append(f'        "{inp.name}": {inp.default!r},')
                else:
                    input_entries.append(f'        "{inp.name}": None,')

        output_entries = []
        for out in info.outputs:
            output_entries.append(f'        "{out.name}": {_output_type_to_gradio(out.type_hint, out.name)},')

        node_parts = [f"\n{var_name} = FnNode(", f"    fn={fn_name},", f'    name="{display_name}",']
        if input_entries:
            node_parts.append("    inputs={")
            node_parts.extend(input_entries)
            node_parts.append("    },")
        if output_entries:
            node_parts.append("    outputs={")
            node_parts.extend(output_entries)
            node_parts.append("    },")
        node_parts.append(")")
        sections.append("\n".join(node_parts))

    # -- Graph launch with cache disabled --
    all_node_vars = input_node_var_names + [node_var_names[info.name] for info in block_infos]
    graph_name = f"{blocks_class_name} - {workflow_label}"
    nodes_str = ", ".join(all_node_vars)

    sections.append("")
    sections.append("")
    sections.append("# -- Graph --")
    sections.extend(_CACHE_DISABLE_TEMPLATE.splitlines())
    sections.append("")
    sections.append(f'graph = Graph("{graph_name}", nodes=[{nodes_str}])')
    sections.append("graph.launch()")
    sections.append("")

    return "\n".join(sections)
|
||||
@@ -16,6 +16,7 @@
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from .custom_blocks import CustomBlocksCommand
|
||||
from .daggr_app import DaggrCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .fp16_safetensors import FP16SafetensorsCommand
|
||||
|
||||
@@ -28,6 +29,7 @@ def main():
|
||||
EnvironmentCommand.register_subcommand(commands_parser)
|
||||
FP16SafetensorsCommand.register_subcommand(commands_parser)
|
||||
CustomBlocksCommand.register_subcommand(commands_parser)
|
||||
DaggrCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
Reference in New Issue
Block a user