mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-03 07:10:34 +08:00
Compare commits
5 Commits
update-mod
...
dataclass-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c99566bab | ||
|
|
39188248a7 | ||
|
|
9b97932424 | ||
|
|
680076fcc0 | ||
|
|
5910a1cc6c |
@@ -14,4 +14,8 @@
|
||||
|
||||
## AutoPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks
|
||||
@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# List of sub-block classes to choose from
|
||||
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
# Names for each block in the same order
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
# Trigger inputs that determine which block to run
|
||||
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# - "image" triggers img2img workflow (but only if mask is not provided)
|
||||
# - if none of above, runs the text2img workflow (default)
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
# Description is extremely important for AutoPipelineBlocks
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
)
|
||||
```
|
||||
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
|
||||
Create an instance of `AutoImageBlocks`.
|
||||
|
||||
@@ -152,5 +152,74 @@ auto_blocks = AutoImageBlocks()
|
||||
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the a block that is actually run based on your input.
|
||||
|
||||
```py
|
||||
auto_blocks.get_execution_blocks("mask")
|
||||
auto_blocks.get_execution_blocks(mask=True)
|
||||
```
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
|
||||
|
||||
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ConditionalPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
|
||||
+ " - inpaint workflow is run when `mask` is provided.\n"
|
||||
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
|
||||
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
|
||||
)
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name ("text2img")
|
||||
```
|
||||
|
||||
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
|
||||
|
||||
## Workflows
|
||||
|
||||
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks]`) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
|
||||
|
||||
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
|
||||
class MyPipelineBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
|
||||
block_names = ["text_encoder", "auto_image", "decode"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
"inpaint": {"mask": True, "image": True, "prompt": True},
|
||||
}
|
||||
```
|
||||
|
||||
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
|
||||
|
||||
```py
|
||||
pipeline_blocks = MyPipelineBlocks()
|
||||
pipeline_blocks.available_workflows
|
||||
# ['text2image', 'image2image', 'inpaint']
|
||||
```
|
||||
|
||||
Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.
|
||||
|
||||
```py
|
||||
pipeline_blocks.get_workflow("inpaint")
|
||||
```
|
||||
@@ -648,6 +648,28 @@ class ConfigMixin:
|
||||
)
|
||||
return config_file
|
||||
|
||||
@classmethod
|
||||
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
|
||||
sig = inspect.signature(cls.__init__)
|
||||
fields = []
|
||||
for name, param in sig.parameters.items():
|
||||
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
|
||||
continue
|
||||
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
|
||||
if param.default is not inspect.Parameter.empty:
|
||||
fields.append((name, annotation, dataclasses.field(default=param.default)))
|
||||
else:
|
||||
fields.append((name, annotation))
|
||||
|
||||
dc_cls = dataclasses.make_dataclass(
|
||||
f"{cls.__name__}Config",
|
||||
fields,
|
||||
frozen=True,
|
||||
)
|
||||
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
|
||||
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
|
||||
return dc_cls(**init_kwargs)
|
||||
|
||||
|
||||
def register_to_config(init):
|
||||
r"""
|
||||
|
||||
@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
|
||||
if has_diffb:
|
||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||
if zero_status_diff_b:
|
||||
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
state_dict = {
|
||||
_custom_replace(k, limit_substrings): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith(("lora_unet_", "lora_te_"))
|
||||
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
|
||||
}
|
||||
|
||||
if any("text_projection" in k for k in state_dict):
|
||||
|
||||
@@ -1633,7 +1633,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name, None)
|
||||
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
|
||||
# fall back to default_blocks_name
|
||||
if blocks_class is None or not blocks_class.block_classes:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
|
||||
if blocks_class is not None:
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
@@ -1836,7 +1843,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
update_model_card = kwargs.pop("update_model_card", False)
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
# Generate modular pipeline card content
|
||||
@@ -1849,7 +1855,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
is_pipeline=True,
|
||||
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
|
||||
is_modular=True,
|
||||
update_model_card=update_model_card,
|
||||
)
|
||||
model_card = populate_model_card(model_card, tags=card_content["tags"])
|
||||
|
||||
|
||||
@@ -50,7 +50,11 @@ This modular pipeline is composed of the following blocks:
|
||||
|
||||
{components_description} {configs_section}
|
||||
|
||||
{io_specification_section}
|
||||
## Input/Output Specification
|
||||
|
||||
### Inputs {inputs_description}
|
||||
|
||||
### Outputs {outputs_description}
|
||||
"""
|
||||
|
||||
|
||||
@@ -795,46 +799,6 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
|
||||
return format_params(output_params, "Outputs", indent_level, max_line_length)
|
||||
|
||||
|
||||
def format_params_markdown(params, header="Inputs"):
|
||||
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
|
||||
|
||||
Suitable for model cards rendered on Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
params: list of InputParam or OutputParam objects to format
|
||||
header: Header text (e.g. "Inputs" or "Outputs")
|
||||
|
||||
Returns:
|
||||
A formatted markdown string, or empty string if params is empty.
|
||||
"""
|
||||
if not params:
|
||||
return ""
|
||||
|
||||
def get_type_str(type_hint):
|
||||
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
|
||||
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
|
||||
return " | ".join(type_strs)
|
||||
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
|
||||
|
||||
lines = [f"**{header}:**\n"] if header else []
|
||||
for param in params:
|
||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
|
||||
param_str = f"- `{name}` (`{type_str}`"
|
||||
|
||||
if hasattr(param, "required") and not param.required:
|
||||
param_str += ", *optional*"
|
||||
if param.default is not None:
|
||||
param_str += f", defaults to `{param.default}`"
|
||||
param_str += ")"
|
||||
|
||||
desc = param.description if param.description else "No description provided"
|
||||
param_str += f": {desc}"
|
||||
lines.append(param_str)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
|
||||
"""Format a list of ComponentSpec objects into a readable string representation.
|
||||
|
||||
@@ -1091,7 +1055,8 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
- blocks_description: Detailed architecture of blocks
|
||||
- components_description: List of required components
|
||||
- configs_section: Configuration parameters section
|
||||
- io_specification_section: Input/Output specification (per-workflow or unified)
|
||||
- inputs_description: Input parameters specification
|
||||
- outputs_description: Output parameters specification
|
||||
- trigger_inputs_section: Conditional execution information
|
||||
- tags: List of relevant tags for the model card
|
||||
"""
|
||||
@@ -1110,6 +1075,15 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if block_desc:
|
||||
blocks_desc_parts.append(f" - {block_desc}")
|
||||
|
||||
# add sub-blocks if any
|
||||
if hasattr(block, "sub_blocks") and block.sub_blocks:
|
||||
for sub_name, sub_block in block.sub_blocks.items():
|
||||
sub_class = sub_block.__class__.__name__
|
||||
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
|
||||
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
|
||||
if sub_desc:
|
||||
blocks_desc_parts.append(f" - {sub_desc}")
|
||||
|
||||
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
|
||||
|
||||
components = getattr(blocks, "expected_components", [])
|
||||
@@ -1135,76 +1109,63 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if configs_description:
|
||||
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
|
||||
|
||||
# Branch on whether workflows are defined
|
||||
has_workflows = getattr(blocks, "_workflow_map", None) is not None
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
|
||||
if has_workflows:
|
||||
workflow_map = blocks._workflow_map
|
||||
parts = []
|
||||
# format inputs as markdown list
|
||||
inputs_parts = []
|
||||
required_inputs = [inp for inp in inputs if inp.required]
|
||||
optional_inputs = [inp for inp in inputs if not inp.required]
|
||||
|
||||
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
|
||||
# use that as the shared output for all workflows
|
||||
blocks_outputs = blocks.outputs
|
||||
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
|
||||
shared_outputs = (
|
||||
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
|
||||
)
|
||||
if required_inputs:
|
||||
inputs_parts.append("**Required:**\n")
|
||||
for inp in required_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
|
||||
|
||||
parts.append("## Workflow Input Specification\n")
|
||||
if optional_inputs:
|
||||
if required_inputs:
|
||||
inputs_parts.append("")
|
||||
inputs_parts.append("**Optional:**\n")
|
||||
for inp in optional_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
|
||||
|
||||
# Per-workflow details: show trigger inputs with full param descriptions
|
||||
for wf_name, trigger_inputs in workflow_map.items():
|
||||
trigger_input_names = set(trigger_inputs.keys())
|
||||
try:
|
||||
workflow_blocks = blocks.get_workflow(wf_name)
|
||||
except Exception:
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
parts.append("*Could not resolve workflow blocks.*\n")
|
||||
parts.append("</details>\n")
|
||||
continue
|
||||
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
|
||||
|
||||
wf_inputs = workflow_blocks.inputs
|
||||
# Show only trigger inputs with full parameter descriptions
|
||||
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
|
||||
# format outputs as markdown list
|
||||
outputs_parts = []
|
||||
for out in outputs:
|
||||
if hasattr(out.type_hint, "__name__"):
|
||||
type_str = out.type_hint.__name__
|
||||
elif out.type_hint is not None:
|
||||
type_str = str(out.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = out.description or "No description provided"
|
||||
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
|
||||
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
|
||||
|
||||
inputs_str = format_params_markdown(trigger_params, header=None)
|
||||
parts.append(inputs_str if inputs_str else "No additional inputs required.")
|
||||
parts.append("")
|
||||
|
||||
parts.append("</details>\n")
|
||||
|
||||
# Common Inputs & Outputs section (like non-workflow pipelines)
|
||||
all_inputs = blocks.inputs
|
||||
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
|
||||
|
||||
inputs_str = format_params_markdown(all_inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(all_outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
|
||||
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
|
||||
|
||||
io_specification_section = "\n".join(parts)
|
||||
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
|
||||
trigger_inputs_section = ""
|
||||
else:
|
||||
# Unified I/O section (original behavior)
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
inputs_str = format_params_markdown(inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
|
||||
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
### Conditional Execution
|
||||
|
||||
This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
@@ -1217,18 +1178,7 @@ This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
if hasattr(blocks, "model_name") and blocks.model_name:
|
||||
tags.append(blocks.model_name)
|
||||
|
||||
if has_workflows:
|
||||
# Derive tags from workflow names
|
||||
workflow_names = set(blocks._workflow_map.keys())
|
||||
if any("inpainting" in wf for wf in workflow_names):
|
||||
tags.append("inpainting")
|
||||
if any("image2image" in wf for wf in workflow_names):
|
||||
tags.append("image-to-image")
|
||||
if any("controlnet" in wf for wf in workflow_names):
|
||||
tags.append("controlnet")
|
||||
if any("text2image" in wf for wf in workflow_names):
|
||||
tags.append("text-to-image")
|
||||
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
triggers = blocks.trigger_inputs
|
||||
if any(t in triggers for t in ["mask", "mask_image"]):
|
||||
tags.append("inpainting")
|
||||
@@ -1256,7 +1206,8 @@ This pipeline uses a {block_count}-block architecture that can be customized and
|
||||
"blocks_description": blocks_description,
|
||||
"components_description": components_description,
|
||||
"configs_section": configs_section,
|
||||
"io_specification_section": io_specification_section,
|
||||
"inputs_description": inputs_description,
|
||||
"outputs_description": outputs_description,
|
||||
"trigger_inputs_section": trigger_inputs_section,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@@ -107,7 +107,6 @@ def load_or_create_model_card(
|
||||
widget: list[dict] | None = None,
|
||||
inference: bool | None = None,
|
||||
is_modular: bool = False,
|
||||
update_model_card: bool = False,
|
||||
) -> ModelCard:
|
||||
"""
|
||||
Loads or creates a model card.
|
||||
@@ -134,9 +133,6 @@ def load_or_create_model_card(
|
||||
`load_or_create_model_card` from a training script.
|
||||
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
|
||||
When True, uses model_description as-is without additional template formatting.
|
||||
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
|
||||
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
|
||||
supported for modular pipelines (i.e., `is_modular=True`).
|
||||
"""
|
||||
if not is_jinja_available():
|
||||
raise ValueError(
|
||||
@@ -145,17 +141,9 @@ def load_or_create_model_card(
|
||||
" To install it, please run `pip install Jinja2`."
|
||||
)
|
||||
|
||||
if update_model_card and not is_modular:
|
||||
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
|
||||
|
||||
try:
|
||||
# Check if the model card is present on the remote repo
|
||||
model_card = ModelCard.load(repo_id_or_path, token=token)
|
||||
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
|
||||
if update_model_card and is_modular and model_description is not None:
|
||||
existing_data = model_card.data
|
||||
model_card = ModelCard(model_description)
|
||||
model_card.data = existing_data
|
||||
except (EntryNotFoundError, RepositoryNotFoundError):
|
||||
# Otherwise create a model card from template
|
||||
if from_training:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Callable
|
||||
|
||||
@@ -349,6 +351,33 @@ class ModularPipelineTesterMixin:
|
||||
|
||||
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
def test_modular_index_consistency(self):
|
||||
pipe = self.get_pipeline()
|
||||
components_spec = pipe._component_specs
|
||||
components = sorted(components_spec.keys())
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
pipe.save_pretrained(tmpdir)
|
||||
index_file = os.path.join(tmpdir, "modular_model_index.json")
|
||||
assert os.path.exists(index_file)
|
||||
|
||||
with open(index_file) as f:
|
||||
index_contents = json.load(f)
|
||||
|
||||
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
|
||||
for k in compulsory_keys:
|
||||
assert k in index_contents
|
||||
|
||||
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
|
||||
for component in components:
|
||||
spec = components_spec[component]
|
||||
for attr in to_check_attrs:
|
||||
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
|
||||
for attr in to_check_attrs:
|
||||
assert component in index_contents, f"{component} should be present in index but isn't."
|
||||
attr_value_from_index = index_contents[component][2][attr]
|
||||
assert getattr(spec, attr) == attr_value_from_index
|
||||
|
||||
def test_workflow_map(self):
|
||||
blocks = self.pipeline_blocks_class()
|
||||
if blocks._workflow_map is None:
|
||||
@@ -454,7 +483,8 @@ class TestModularModelCardContent:
|
||||
"blocks_description",
|
||||
"components_description",
|
||||
"configs_section",
|
||||
"io_specification_section",
|
||||
"inputs_description",
|
||||
"outputs_description",
|
||||
"trigger_inputs_section",
|
||||
"tags",
|
||||
]
|
||||
@@ -551,19 +581,18 @@ class TestModularModelCardContent:
|
||||
blocks = self.create_mock_blocks(inputs=inputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
io_section = content["io_specification_section"]
|
||||
assert "**Inputs:**" in io_section
|
||||
assert "prompt" in io_section
|
||||
assert "num_steps" in io_section
|
||||
assert "*optional*" in io_section
|
||||
assert "defaults to `50`" in io_section
|
||||
assert "**Required:**" in content["inputs_description"]
|
||||
assert "**Optional:**" in content["inputs_description"]
|
||||
assert "prompt" in content["inputs_description"]
|
||||
assert "num_steps" in content["inputs_description"]
|
||||
assert "default: `50`" in content["inputs_description"]
|
||||
|
||||
def test_inputs_description_empty(self):
|
||||
"""Test handling of pipelines without specific inputs."""
|
||||
blocks = self.create_mock_blocks(inputs=[])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "No specific inputs defined" in content["io_specification_section"]
|
||||
assert "No specific inputs defined" in content["inputs_description"]
|
||||
|
||||
def test_outputs_description_formatting(self):
|
||||
"""Test that outputs are correctly formatted."""
|
||||
@@ -573,16 +602,15 @@ class TestModularModelCardContent:
|
||||
blocks = self.create_mock_blocks(outputs=outputs)
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
io_section = content["io_specification_section"]
|
||||
assert "images" in io_section
|
||||
assert "Generated images" in io_section
|
||||
assert "images" in content["outputs_description"]
|
||||
assert "Generated images" in content["outputs_description"]
|
||||
|
||||
def test_outputs_description_empty(self):
|
||||
"""Test handling of pipelines without specific outputs."""
|
||||
blocks = self.create_mock_blocks(outputs=[])
|
||||
content = generate_modular_model_card_content(blocks)
|
||||
|
||||
assert "Standard pipeline outputs" in content["io_specification_section"]
|
||||
assert "Standard pipeline outputs" in content["outputs_description"]
|
||||
|
||||
def test_trigger_inputs_section_with_triggers(self):
|
||||
"""Test that trigger inputs section is generated when present."""
|
||||
@@ -700,3 +728,27 @@ class TestLoadComponentsSkipBehavior:
|
||||
|
||||
# Verify test_component was not loaded
|
||||
assert not hasattr(pipe, "test_component") or pipe.test_component is None
|
||||
|
||||
|
||||
class TestModularPipelineInitFallback:
|
||||
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
|
||||
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""
|
||||
|
||||
def test_init_fallback_when_blocks_class_name_is_base_class(self, tmp_path):
|
||||
# 1. Load pipeline and get a workflow (returns a base SequentialPipelineBlocks)
|
||||
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
|
||||
t2i_blocks = pipe.blocks.get_workflow("text2image")
|
||||
assert t2i_blocks.__class__.__name__ == "SequentialPipelineBlocks"
|
||||
|
||||
# 2. Use init_pipeline to create a new pipeline from the workflow blocks
|
||||
t2i_pipe = t2i_blocks.init_pipeline("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
|
||||
|
||||
# 3. Save and reload — the saved config will have _blocks_class_name="SequentialPipelineBlocks"
|
||||
save_dir = str(tmp_path / "pipeline")
|
||||
t2i_pipe.save_pretrained(save_dir)
|
||||
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
|
||||
|
||||
# 4. Verify it fell back to default_blocks_name and has correct blocks
|
||||
assert loaded_pipe.__class__.__name__ == pipe.__class__.__name__
|
||||
assert loaded_pipe._blocks.__class__.__name__ == pipe._blocks.__class__.__name__
|
||||
assert len(loaded_pipe._blocks.sub_blocks) == len(pipe._blocks.sub_blocks)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
@@ -305,3 +306,96 @@ class ConfigTester(unittest.TestCase):
|
||||
result = json.loads(json_string)
|
||||
assert result["test_file_1"] == config.config.test_file_1.as_posix()
|
||||
assert result["test_file_2"] == config.config.test_file_2.as_posix()
|
||||
|
||||
|
||||
class SampleObjectTyped(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a: int = 2,
|
||||
b: int = 5,
|
||||
c: str = "hello",
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class SampleObjectWithIgnore(ConfigMixin):
|
||||
config_name = "config.json"
|
||||
ignore_for_config = ["secret"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
a: int = 2,
|
||||
secret: str = "hidden",
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class DataclassFromConfigTester(unittest.TestCase):
|
||||
def test_get_dataclass_from_config_returns_frozen_dataclass(self):
|
||||
obj = SampleObject()
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert dataclasses.is_dataclass(tc)
|
||||
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||
tc.a = 99
|
||||
|
||||
def test_get_dataclass_from_config_class_name(self):
|
||||
obj = SampleObject()
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert type(tc).__name__ == "SampleObjectConfig"
|
||||
|
||||
def test_get_dataclass_from_config_values_match_config(self):
|
||||
obj = SampleObject(a=10, b=20)
|
||||
tc = SampleObject._get_dataclass_from_config(dict(obj.config))
|
||||
assert tc.a == 10
|
||||
assert tc.b == 20
|
||||
assert tc.c == (2, 5)
|
||||
assert tc.d == "for diffusion"
|
||||
assert tc.e == [1, 3]
|
||||
|
||||
def test_get_dataclass_from_config_from_raw_dict(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||
assert tc.a == 7
|
||||
assert tc.b == 3
|
||||
assert tc.c == "world"
|
||||
|
||||
def test_get_dataclass_from_config_annotations(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
|
||||
fields = {f.name: f.type for f in dataclasses.fields(tc)}
|
||||
assert fields["a"] is int
|
||||
assert fields["b"] is int
|
||||
assert fields["c"] is str
|
||||
|
||||
def test_get_dataclass_from_config_asdict_roundtrip(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
|
||||
d = dataclasses.asdict(tc)
|
||||
assert d == {"a": 7, "b": 3, "c": "world"}
|
||||
|
||||
def test_get_dataclass_from_config_ignores_extra_keys(self):
|
||||
tc = SampleObjectTyped._get_dataclass_from_config(
|
||||
{"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
|
||||
)
|
||||
assert tc.a == 1
|
||||
assert not hasattr(tc, "_class_name")
|
||||
assert not hasattr(tc, "extra")
|
||||
|
||||
def test_get_dataclass_from_config_respects_ignore_for_config(self):
|
||||
tc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
|
||||
assert not hasattr(tc, "secret")
|
||||
assert tc.a == 5
|
||||
|
||||
def test_get_dataclass_from_config_works_for_scheduler(self):
|
||||
scheduler = DDIMScheduler()
|
||||
tc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
|
||||
assert dataclasses.is_dataclass(tc)
|
||||
assert type(tc).__name__ == "DDIMSchedulerConfig"
|
||||
assert tc.num_train_timesteps == scheduler.config.num_train_timesteps
|
||||
|
||||
def test_get_dataclass_from_config_different_values(self):
|
||||
tc1 = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
|
||||
tc2 = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
|
||||
assert tc1.a == 1
|
||||
assert tc2.a == 9
|
||||
|
||||
Reference in New Issue
Block a user