Compare commits

..

5 Commits

Author SHA1 Message Date
Sayak Paul
9e87e3d3be Merge branch 'main' into modular-index-tests 2026-02-27 15:46:35 +05:30
Christopher
5910a1cc6c Fixing Kohya loras loading: Flux.1-dev loras with TE ("lora_te1_" prefix) (#13188)
* fixing text encoder lora loading

* following Cursor's review
2026-02-27 15:43:41 +05:30
Sayak Paul
770a149173 Merge branch 'main' into modular-index-tests 2026-02-27 15:10:45 +05:30
sayakpaul
94457fd6b1 check for compulsory keys. 2026-02-27 15:02:17 +05:30
sayakpaul
6ebd990336 add a test to check modular index consistency 2026-02-27 14:59:58 +05:30
3 changed files with 104 additions and 130 deletions

View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_"))
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
}
if any("text_projection" in k for k in state_dict):

View File

@@ -48,7 +48,13 @@ This modular pipeline is composed of the following blocks:
## Model Components
{components_description} {configs_section} {io_specification_section}
{components_description} {configs_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
"""
@@ -793,46 +799,6 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"]
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -1089,7 +1055,8 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- io_specification_section: Input/Output specification (per-workflow or unified)
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1142,74 +1109,63 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
inputs = blocks.inputs
outputs = blocks.outputs
if has_workflows:
# Per-workflow I/O sections
workflow_map = blocks._workflow_map
parts = []
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
# Summary section using existing format_workflow
parts.append("## Supported Workflows\n")
parts.append(format_workflow(workflow_map))
parts.append("")
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
# Per-workflow details
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append(f"> **Trigger inputs**: {', '.join(f'`{t}`' for t in trigger_input_names)}\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
wf_inputs = workflow_blocks.inputs
wf_outputs = shared_outputs if shared_outputs is not None else workflow_blocks.outputs
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append(f"> **Trigger inputs**: {', '.join(f'`{t}`' for t in trigger_input_names)}\n")
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
inputs_str = format_params_markdown(wf_inputs, "Inputs")
parts.append(inputs_str if inputs_str else "No specific inputs defined.")
parts.append("")
outputs_str = format_params_markdown(wf_outputs, "Outputs")
parts.append(outputs_str if outputs_str else "No specific outputs defined.")
parts.append("")
parts.append("</details>\n")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1222,18 +1178,7 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1261,7 +1206,8 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"io_specification_section": io_specification_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -1,4 +1,6 @@
import gc
import json
import os
import tempfile
from typing import Callable
@@ -349,6 +351,33 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_modular_index_consistency(self):
pipe = self.get_pipeline()
components_spec = pipe._component_specs
components = sorted(components_spec.keys())
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir)
index_file = os.path.join(tmpdir, "modular_model_index.json")
assert os.path.exists(index_file)
with open(index_file) as f:
index_contents = json.load(f)
compulsory_keys = {"_blocks_class_name", "_class_name", "_diffusers_version"}
for k in compulsory_keys:
assert k in index_contents
to_check_attrs = {"pretrained_model_name_or_path", "revision", "subfolder"}
for component in components:
spec = components_spec[component]
for attr in to_check_attrs:
if getattr(spec, "pretrained_model_name_or_path", None) is not None:
for attr in to_check_attrs:
assert component in index_contents, f"{component} should be present in index but isn't."
attr_value_from_index = index_contents[component][2][attr]
assert getattr(spec, attr) == attr_value_from_index
def test_workflow_map(self):
blocks = self.pipeline_blocks_class()
if blocks._workflow_map is None:
@@ -454,7 +483,8 @@ class TestModularModelCardContent:
"blocks_description",
"components_description",
"configs_section",
"io_specification_section",
"inputs_description",
"outputs_description",
"trigger_inputs_section",
"tags",
]
@@ -551,19 +581,18 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
io_section = content["io_specification_section"]
assert "**Inputs:**" in io_section
assert "prompt" in io_section
assert "num_steps" in io_section
assert "*optional*" in io_section
assert "defaults to `50`" in io_section
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["io_specification_section"]
assert "No specific inputs defined" in content["inputs_description"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
@@ -573,16 +602,15 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
io_section = content["io_specification_section"]
assert "images" in io_section
assert "Generated images" in io_section
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["io_specification_section"]
assert "Standard pipeline outputs" in content["outputs_description"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""