Compare commits

..

1 Commits

Author SHA1 Message Date
sayakpaul
b3bab6b273 make the glm image tests lighter. 2026-01-15 09:53:38 +05:30
11 changed files with 64 additions and 1570 deletions

View File

@@ -99,9 +99,3 @@ image.save("chroma-single-file.png")
[[autodoc]] ChromaImg2ImgPipeline
- all
- __call__
## ChromaInpaintPipeline
[[autodoc]] ChromaInpaintPipeline
- all
- __call__

View File

@@ -460,7 +460,6 @@ else:
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaInpaintPipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
@@ -1187,7 +1186,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaInpaintPipeline,
ChromaPipeline,
ChronoEditPipeline,
CLIPImageProjection,

View File

@@ -1573,6 +1573,8 @@ def _templated_context_parallel_attention(
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("Attention mask is not yet supported for templated attention.")
if is_causal:
raise ValueError("Causal attention is not yet supported for templated attention.")
if enable_gqa:

View File

@@ -761,14 +761,11 @@ class QwenImageTransformer2DModel(
_no_split_modules = ["QwenImageTransformerBlock"]
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
_repeated_blocks = ["QwenImageTransformerBlock"]
# Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
_cp_plan = {
"transformer_blocks.0": {
"": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"transformer_blocks.*": {
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
},
"pos_embed": {
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),

View File

@@ -4,7 +4,7 @@ import os
# Simple typed wrapper for parameter overrides
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, Optional, Union
from huggingface_hub import create_repo, hf_hub_download, upload_folder
from huggingface_hub.utils import (
@@ -42,54 +42,35 @@ class MellonParam:
fieldOptions: Optional[Dict[str, Any]] = None
onChange: Any = None
onSignal: Any = None
required_block_params: Optional[Union[str, List[str]]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dict for Mellon schema, excluding None values and name."""
data = asdict(self)
return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")}
return {k: v for k, v in data.items() if v is not None and k != "name"}
@classmethod
def image(cls) -> "MellonParam":
return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
return cls(name="image", label="Image", type="image", display="input")
@classmethod
def images(cls) -> "MellonParam":
return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
return cls(name="images", label="Images", type="image", display="output")
@classmethod
def control_image(cls, display: str = "input") -> "MellonParam":
return cls(
name="control_image",
label="Control Image",
type="image",
display=display,
required_block_params=["control_image"],
)
return cls(name="control_image", label="Control Image", type="image", display=display)
@classmethod
def latents(cls, display: str = "input") -> "MellonParam":
return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
return cls(name="latents", label="Latents", type="latents", display=display)
@classmethod
def image_latents(cls, display: str = "input") -> "MellonParam":
return cls(
name="image_latents",
label="Image Latents",
type="latents",
display=display,
required_block_params=["image_latents"],
)
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
@classmethod
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
return cls(
name="first_frame_latents",
label="First Frame Latents",
type="latents",
display=display,
required_block_params=["first_frame_latents"],
)
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
@classmethod
def image_latents_with_strength(cls) -> "MellonParam":
@@ -99,7 +80,6 @@ class MellonParam:
type="latents",
display="input",
onChange={"false": ["height", "width"], "true": ["strength"]},
required_block_params=["image_latents", "strength"],
)
@classmethod
@@ -115,13 +95,7 @@ class MellonParam:
@classmethod
def image_embeds(cls, display: str = "output") -> "MellonParam":
return cls(
name="image_embeds",
label="Image Embeddings",
type="image_embeds",
display=display,
required_block_params=["image_embeds"],
)
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
@classmethod
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
@@ -133,7 +107,6 @@ class MellonParam:
min=0.0,
max=1.0,
step=0.01,
required_block_params=["controlnet_conditioning_scale"],
)
@classmethod
@@ -146,7 +119,6 @@ class MellonParam:
min=0.0,
max=1.0,
step=0.01,
required_block_params=["control_guidance_start"],
)
@classmethod
@@ -159,43 +131,19 @@ class MellonParam:
min=0.0,
max=1.0,
step=0.01,
required_block_params=["control_guidance_end"],
)
@classmethod
def prompt(cls, default: str = "") -> "MellonParam":
return cls(
name="prompt",
label="Prompt",
type="string",
default=default,
display="textarea",
required_block_params=["prompt"],
)
return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea")
@classmethod
def negative_prompt(cls, default: str = "") -> "MellonParam":
return cls(
name="negative_prompt",
label="Negative Prompt",
type="string",
default=default,
display="textarea",
required_block_params=["negative_prompt"],
)
return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea")
@classmethod
def strength(cls, default: float = 0.5) -> "MellonParam":
return cls(
name="strength",
label="Strength",
type="float",
default=default,
min=0.0,
max=1.0,
step=0.01,
required_block_params=["strength"],
)
return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01)
@classmethod
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
@@ -212,77 +160,33 @@ class MellonParam:
@classmethod
def height(cls, default: int = 1024) -> "MellonParam":
return cls(
name="height",
label="Height",
type="int",
default=default,
min=64,
step=8,
required_block_params=["height"],
)
return cls(name="height", label="Height", type="int", default=default, min=64, step=8)
@classmethod
def width(cls, default: int = 1024) -> "MellonParam":
return cls(
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
)
return cls(name="width", label="Width", type="int", default=default, min=64, step=8)
@classmethod
def seed(cls, default: int = 0) -> "MellonParam":
return cls(
name="seed",
label="Seed",
type="int",
default=default,
min=0,
max=4294967295,
display="random",
required_block_params=["generator"],
)
return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random")
@classmethod
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
return cls(
name="num_inference_steps",
label="Steps",
type="int",
default=default,
min=1,
max=100,
display="slider",
required_block_params=["num_inference_steps"],
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
)
@classmethod
def num_frames(cls, default: int = 81) -> "MellonParam":
return cls(
name="num_frames",
label="Frames",
type="int",
default=default,
min=1,
max=480,
display="slider",
required_block_params=["num_frames"],
)
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
@classmethod
def layers(cls, default: int = 4) -> "MellonParam":
return cls(
name="layers",
label="Layers",
type="int",
default=default,
min=1,
max=10,
display="slider",
required_block_params=["layers"],
)
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider")
@classmethod
def videos(cls) -> "MellonParam":
return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
return cls(name="videos", label="Videos", type="video", display="output")
@classmethod
def vae(cls) -> "MellonParam":
@@ -292,9 +196,7 @@ class MellonParam:
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
"""
return cls(
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
)
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
@classmethod
def image_encoder(cls) -> "MellonParam":
@@ -304,13 +206,7 @@ class MellonParam:
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
"""
return cls(
name="image_encoder",
label="Image Encoder",
type="diffusers_auto_model",
display="input",
required_block_params=["image_encoder"],
)
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
@classmethod
def unet(cls) -> "MellonParam":
@@ -340,13 +236,7 @@ class MellonParam:
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
the actual model.
"""
return cls(
name="controlnet",
label="ControlNet Model",
type="diffusers_auto_model",
display="input",
required_block_params=["controlnet"],
)
return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input")
@classmethod
def text_encoders(cls) -> "MellonParam":
@@ -358,13 +248,7 @@ class MellonParam:
'repo_id': '...'
} Use components.get_one(model_id) to retrieve each model.
"""
return cls(
name="text_encoders",
label="Text Encoders",
type="diffusers_auto_models",
display="input",
required_block_params=["text_encoder"],
)
return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input")
@classmethod
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
@@ -379,13 +263,7 @@ class MellonParam:
Output from Controlnet node, input to Denoise node.
"""
return cls(
name="controlnet_bundle",
label="ControlNet",
type="custom_controlnet",
display=display,
required_block_params="controlnet_image",
)
return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display)
@classmethod
def ip_adapter(cls) -> "MellonParam":
@@ -406,86 +284,6 @@ class MellonParam:
return cls(name="doc", label="Doc", type="string", display="output")
DEFAULT_NODE_SPECS = {
"controlnet": None,
"denoise": {
"inputs": [
MellonParam.embeddings(display="input"),
MellonParam.width(),
MellonParam.height(),
MellonParam.seed(),
MellonParam.num_inference_steps(),
MellonParam.guidance_scale(),
MellonParam.strength(),
MellonParam.image_latents_with_strength(),
MellonParam.image_latents(),
MellonParam.first_frame_latents(),
MellonParam.controlnet_bundle(display="input"),
],
"model_inputs": [
MellonParam.unet(),
MellonParam.guider(),
MellonParam.scheduler(),
],
"outputs": [
MellonParam.latents(display="output"),
MellonParam.latents_preview(),
MellonParam.doc(),
],
"required_inputs": ["embeddings"],
"required_model_inputs": ["unet", "scheduler"],
"block_name": "denoise",
},
"vae_encoder": {
"inputs": [
MellonParam.image(),
],
"model_inputs": [
MellonParam.vae(),
],
"outputs": [
MellonParam.image_latents(display="output"),
MellonParam.doc(),
],
"required_inputs": ["image"],
"required_model_inputs": ["vae"],
"block_name": "vae_encoder",
},
"text_encoder": {
"inputs": [
MellonParam.prompt(),
MellonParam.negative_prompt(),
],
"model_inputs": [
MellonParam.text_encoders(),
],
"outputs": [
MellonParam.embeddings(display="output"),
MellonParam.doc(),
],
"required_inputs": ["prompt"],
"required_model_inputs": ["text_encoders"],
"block_name": "text_encoder",
},
"decoder": {
"inputs": [
MellonParam.latents(display="input"),
],
"model_inputs": [
MellonParam.vae(),
],
"outputs": [
MellonParam.images(),
MellonParam.videos(),
MellonParam.doc(),
],
"required_inputs": ["latents"],
"required_model_inputs": ["vae"],
"block_name": "decode",
},
}
def mark_required(label: str, marker: str = " *") -> str:
"""Add required marker to label if not already present."""
if label.endswith(marker):
@@ -660,39 +458,20 @@ class MellonPipelineConfig:
default_dtype: Default dtype (e.g., "float16", "bfloat16")
"""
# Convert all node specs to Mellon format immediately
self.node_specs = node_specs
self.node_params = {}
for node_type, spec in node_specs.items():
if spec is None:
self.node_params[node_type] = None
else:
self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type)
self.label = label
self.default_repo = default_repo
self.default_dtype = default_dtype
@property
def node_params(self) -> Dict[str, Any]:
"""Lazily compute node_params from node_specs."""
params = {}
for node_type, spec in self.node_specs.items():
if spec is None:
params[node_type] = None
else:
params[node_type] = node_spec_to_mellon_dict(spec, node_type)
return params
def __repr__(self) -> str:
lines = [
f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"
]
for node_type, spec in self.node_specs.items():
if spec is None:
lines.append(f" {node_type}: None")
else:
inputs = [p.name for p in spec.get("inputs", [])]
model_inputs = [p.name for p in spec.get("model_inputs", [])]
outputs = [p.name for p in spec.get("outputs", [])]
lines.append(f" {node_type}:")
lines.append(f" inputs: {inputs}")
lines.append(f" model_inputs: {model_inputs}")
lines.append(f" outputs: {outputs}")
return "\n".join(lines)
node_types = list(self.node_params.keys())
return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})"
def to_dict(self) -> Dict[str, Any]:
"""Convert to a JSON-serializable dictionary."""
@@ -843,85 +622,3 @@ class MellonPipelineConfig:
return cls.from_json_file(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.")
@classmethod
def from_blocks(
cls,
blocks,
template: Dict[str, Optional[Dict[str, Any]]] = None,
label: str = "",
default_repo: str = "",
default_dtype: str = "bfloat16",
) -> "MellonPipelineConfig":
"""
Create MellonPipelineConfig by matching template against actual pipeline blocks.
"""
if template is None:
template = DEFAULT_NODE_SPECS
sub_block_map = dict(blocks.sub_blocks)
def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]:
"""Filter template spec params based on what the block actually supports."""
block_input_names = set(block.input_names)
block_output_names = set(block.intermediate_output_names)
block_component_names = set(block.component_names)
filtered_inputs = [
p
for p in template_spec.get("inputs", [])
if p.required_block_params is None
or all(name in block_input_names for name in p.required_block_params)
]
filtered_model_inputs = [
p
for p in template_spec.get("model_inputs", [])
if p.required_block_params is None
or all(name in block_component_names for name in p.required_block_params)
]
filtered_outputs = [
p
for p in template_spec.get("outputs", [])
if p.required_block_params is None
or all(name in block_output_names for name in p.required_block_params)
]
filtered_input_names = {p.name for p in filtered_inputs}
filtered_model_input_names = {p.name for p in filtered_model_inputs}
filtered_required_inputs = [
r for r in template_spec.get("required_inputs", []) if r in filtered_input_names
]
filtered_required_model_inputs = [
r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names
]
return {
"inputs": filtered_inputs,
"model_inputs": filtered_model_inputs,
"outputs": filtered_outputs,
"required_inputs": filtered_required_inputs,
"required_model_inputs": filtered_required_model_inputs,
"block_name": template_spec.get("block_name"),
}
# Build node specs
node_specs = {}
for node_type, template_spec in template.items():
if template_spec is None:
node_specs[node_type] = None
continue
block_name = template_spec.get("block_name")
if block_name is None or block_name not in sub_block_map:
node_specs[node_type] = None
continue
node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name])
return cls(
node_specs=node_specs,
label=label or getattr(blocks, "model_name", ""),
default_repo=default_repo,
default_dtype=default_dtype,
)

View File

@@ -155,7 +155,7 @@ else:
"AudioLDM2UNet2DConditionModel",
]
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"]
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
_import_structure["cogvideo"] = [
"CogVideoXPipeline",
"CogVideoXImageToVideoPipeline",
@@ -598,7 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .blip_diffusion import BlipDiffusionPipeline
from .bria import BriaPipeline
from .bria_fibo import BriaFiboPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
from .chronoedit import ChronoEditPipeline
from .cogvideo import (
CogVideoXFunControlPipeline,

View File

@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
_import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
@@ -34,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .pipeline_chroma import ChromaPipeline
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
from .pipeline_chroma_inpainting import ChromaInpaintPipeline
else:
import sys

File diff suppressed because it is too large Load Diff

View File

@@ -260,10 +260,10 @@ class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
all_text.append(text)
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device)
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)
self.text_encoder.to(device)
generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
generated_ids.to(device)
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = self.text_processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False

View File

@@ -632,21 +632,6 @@ class ChromaImg2ImgPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class ChromaInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class ChromaPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -21,20 +21,24 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel
from diffusers.utils import is_transformers_version
from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater
from ...testing_utils import enable_full_determinism, require_transformers_version_greater
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
if is_transformers_version(">=", "5.0.0.dev0"):
from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor
from transformers import (
GlmImageConfig,
GlmImageForConditionalGeneration,
GlmImageImageProcessor,
GlmImageProcessor,
)
enable_full_determinism()
@require_transformers_version_greater("4.57.4")
@require_torch_accelerator
class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = GlmImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"}
@@ -86,7 +90,23 @@ class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
torch.manual_seed(0)
vision_language_encoder = GlmImageForConditionalGeneration(glm_config)
processor = GlmImageProcessor.from_pretrained("zai-org/GLM-Image", subfolder="processor")
# Create small image_processor for testing instead of loading the huge processor
image_processor = GlmImageImageProcessor(
min_pixels=32 * 32,
max_pixels=32 * 32 * 4,
patch_size=8,
merge_size=1,
temporal_patch_size=1,
do_resize=True,
do_rescale=True,
do_normalize=True,
)
# Load the tokenizer from GLM-Image (small, just config files) - it has required attributes
# (image_token, grid_bos_token, grid_eos_token) that get properly serialized
processor_tokenizer = AutoTokenizer.from_pretrained("zai-org/GLM-Image", subfolder="processor")
processor = GlmImageProcessor(image_processor=image_processor, tokenizer=processor_tokenizer)
# Set chat template on processor (it checks self.chat_template, not self.tokenizer.chat_template)
processor.chat_template = processor_tokenizer.chat_template
torch.manual_seed(0)
# For GLM-Image, the relationship between components must satisfy: