mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-17 17:16:15 +08:00
Compare commits
14 Commits
remove-exp
...
transforme
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cec020988b | ||
|
|
926db24add | ||
|
|
37cfceef0d | ||
|
|
ea90a74ed4 | ||
|
|
96f08043a3 | ||
|
|
d0f279ce76 | ||
|
|
7f43cb1d79 | ||
|
|
c5e023fbe6 | ||
|
|
5efb81fa71 | ||
|
|
f8e50fab75 | ||
|
|
b351be2379 | ||
|
|
d8f4dd295f | ||
|
|
c152b1831c | ||
|
|
039324ae16 |
8
.github/workflows/pr_tests.yml
vendored
8
.github/workflows/pr_tests.yml
vendored
@@ -115,8 +115,8 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -247,8 +247,8 @@ jobs:
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
13
.github/workflows/pr_tests_gpu.yml
vendored
13
.github/workflows/pr_tests_gpu.yml
vendored
@@ -14,6 +14,7 @@ on:
|
||||
- "tests/pipelines/test_pipelines_common.py"
|
||||
- "tests/models/test_modeling_common.py"
|
||||
- "examples/**/*.py"
|
||||
- ".github/**.yml"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
@@ -131,8 +132,8 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -202,8 +203,8 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -264,8 +265,8 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
# uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip install -e ".[quality,training]"
|
||||
|
||||
- name: Environment
|
||||
|
||||
@@ -99,3 +99,9 @@ image.save("chroma-single-file.png")
|
||||
[[autodoc]] ChromaImg2ImgPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## ChromaInpaintPipeline
|
||||
|
||||
[[autodoc]] ChromaInpaintPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
@@ -17,6 +17,9 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
@@ -30,6 +33,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
|
||||
class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
@@ -460,6 +460,7 @@ else:
|
||||
"BriaFiboPipeline",
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaInpaintPipeline",
|
||||
"ChromaPipeline",
|
||||
"ChronoEditPipeline",
|
||||
"CLIPImageProjection",
|
||||
@@ -1186,6 +1187,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaFiboPipeline,
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaInpaintPipeline,
|
||||
ChromaPipeline,
|
||||
ChronoEditPipeline,
|
||||
CLIPImageProjection,
|
||||
|
||||
@@ -44,6 +44,7 @@ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose3d,
|
||||
torch.nn.Linear,
|
||||
torch.nn.Embedding,
|
||||
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
|
||||
# because of double invocation of the same norm layer in CogVideoXLayerNorm
|
||||
)
|
||||
|
||||
@@ -19,7 +19,13 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from torch import nn
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import (
|
||||
_get_model_file,
|
||||
is_accelerate_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -549,17 +555,29 @@ class TextualInversionLoaderMixin:
|
||||
|
||||
# Delete from tokenizer
|
||||
for token_id, token_to_remove in zip(token_ids, tokens):
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
del tokenizer._added_tokens_encoder[token_to_remove]
|
||||
if is_transformers_version("<=", "4.58.0"):
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
del tokenizer._added_tokens_encoder[token_to_remove]
|
||||
elif is_transformers_version(">", "4.58.0"):
|
||||
del tokenizer.added_tokens_decoder[token_id]
|
||||
del tokenizer.added_tokens_encoder[token_to_remove]
|
||||
|
||||
# Make all token ids sequential in tokenizer
|
||||
key_id = 1
|
||||
for token_id in tokenizer.added_tokens_decoder:
|
||||
if token_id > last_special_token_id and token_id > last_special_token_id + key_id:
|
||||
token = tokenizer._added_tokens_decoder[token_id]
|
||||
tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
|
||||
if is_transformers_version("<=", "4.58.0"):
|
||||
token = tokenizer._added_tokens_decoder[token_id]
|
||||
tokenizer._added_tokens_decoder[last_special_token_id + key_id] = token
|
||||
del tokenizer._added_tokens_decoder[token_id]
|
||||
elif is_transformers_version(">", "4.58.0"):
|
||||
token = tokenizer.added_tokens_decoder[token_id]
|
||||
tokenizer.added_tokens_decoder[last_special_token_id + key_id] = token
|
||||
del tokenizer.added_tokens_decoder[token_id]
|
||||
if is_transformers_version("<=", "4.58.0"):
|
||||
tokenizer._added_tokens_encoder[token.content] = last_special_token_id + key_id
|
||||
elif is_transformers_version(">", "4.58.0"):
|
||||
tokenizer.added_tokens_encoder[token.content] = last_special_token_id + key_id
|
||||
key_id += 1
|
||||
tokenizer._update_trie()
|
||||
# set correct total vocab size after removing tokens
|
||||
|
||||
@@ -1573,8 +1573,6 @@ def _templated_context_parallel_attention(
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("Attention mask is not yet supported for templated attention.")
|
||||
if is_causal:
|
||||
raise ValueError("Causal attention is not yet supported for templated attention.")
|
||||
if enable_gqa:
|
||||
|
||||
@@ -761,11 +761,14 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
# Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
|
||||
_cp_plan = {
|
||||
"": {
|
||||
"transformer_blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"transformer_blocks.*": {
|
||||
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
# Simple typed wrapper for parameter overrides
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from huggingface_hub.utils import (
|
||||
@@ -42,35 +42,54 @@ class MellonParam:
|
||||
fieldOptions: Optional[Dict[str, Any]] = None
|
||||
onChange: Any = None
|
||||
onSignal: Any = None
|
||||
required_block_params: Optional[Union[str, List[str]]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict for Mellon schema, excluding None values and name."""
|
||||
data = asdict(self)
|
||||
return {k: v for k, v in data.items() if v is not None and k != "name"}
|
||||
return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")}
|
||||
|
||||
@classmethod
|
||||
def image(cls) -> "MellonParam":
|
||||
return cls(name="image", label="Image", type="image", display="input")
|
||||
return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
|
||||
|
||||
@classmethod
|
||||
def images(cls) -> "MellonParam":
|
||||
return cls(name="images", label="Images", type="image", display="output")
|
||||
return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
|
||||
|
||||
@classmethod
|
||||
def control_image(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="control_image", label="Control Image", type="image", display=display)
|
||||
return cls(
|
||||
name="control_image",
|
||||
label="Control Image",
|
||||
type="image",
|
||||
display=display,
|
||||
required_block_params=["control_image"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="latents", label="Latents", type="latents", display=display)
|
||||
return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
|
||||
|
||||
@classmethod
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["image_latents"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
|
||||
return cls(
|
||||
name="first_frame_latents",
|
||||
label="First Frame Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["first_frame_latents"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
@@ -80,6 +99,7 @@ class MellonParam:
|
||||
type="latents",
|
||||
display="input",
|
||||
onChange={"false": ["height", "width"], "true": ["strength"]},
|
||||
required_block_params=["image_latents", "strength"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -95,7 +115,13 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
|
||||
return cls(
|
||||
name="image_embeds",
|
||||
label="Image Embeddings",
|
||||
type="image_embeds",
|
||||
display=display,
|
||||
required_block_params=["image_embeds"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
@@ -107,6 +133,7 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["controlnet_conditioning_scale"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -119,6 +146,7 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["control_guidance_start"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -131,19 +159,43 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["control_guidance_end"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prompt(cls, default: str = "") -> "MellonParam":
|
||||
return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea")
|
||||
return cls(
|
||||
name="prompt",
|
||||
label="Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["prompt"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def negative_prompt(cls, default: str = "") -> "MellonParam":
|
||||
return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea")
|
||||
return cls(
|
||||
name="negative_prompt",
|
||||
label="Negative Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["negative_prompt"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def strength(cls, default: float = 0.5) -> "MellonParam":
|
||||
return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01)
|
||||
return cls(
|
||||
name="strength",
|
||||
label="Strength",
|
||||
type="float",
|
||||
default=default,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["strength"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
|
||||
@@ -160,33 +212,77 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def height(cls, default: int = 1024) -> "MellonParam":
|
||||
return cls(name="height", label="Height", type="int", default=default, min=64, step=8)
|
||||
return cls(
|
||||
name="height",
|
||||
label="Height",
|
||||
type="int",
|
||||
default=default,
|
||||
min=64,
|
||||
step=8,
|
||||
required_block_params=["height"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def width(cls, default: int = 1024) -> "MellonParam":
|
||||
return cls(name="width", label="Width", type="int", default=default, min=64, step=8)
|
||||
return cls(
|
||||
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def seed(cls, default: int = 0) -> "MellonParam":
|
||||
return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random")
|
||||
return cls(
|
||||
name="seed",
|
||||
label="Seed",
|
||||
type="int",
|
||||
default=default,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
display="random",
|
||||
required_block_params=["generator"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
|
||||
return cls(
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
|
||||
name="num_inference_steps",
|
||||
label="Steps",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=100,
|
||||
display="slider",
|
||||
required_block_params=["num_inference_steps"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
return cls(
|
||||
name="num_frames",
|
||||
label="Frames",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=480,
|
||||
display="slider",
|
||||
required_block_params=["num_frames"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider")
|
||||
return cls(
|
||||
name="layers",
|
||||
label="Layers",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=10,
|
||||
display="slider",
|
||||
required_block_params=["layers"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
@@ -196,7 +292,9 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
|
||||
return cls(
|
||||
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
@@ -206,7 +304,13 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
|
||||
return cls(
|
||||
name="image_encoder",
|
||||
label="Image Encoder",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["image_encoder"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
@@ -236,7 +340,13 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input")
|
||||
return cls(
|
||||
name="controlnet",
|
||||
label="ControlNet Model",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["controlnet"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def text_encoders(cls) -> "MellonParam":
|
||||
@@ -248,7 +358,13 @@ class MellonParam:
|
||||
'repo_id': '...'
|
||||
} Use components.get_one(model_id) to retrieve each model.
|
||||
"""
|
||||
return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input")
|
||||
return cls(
|
||||
name="text_encoders",
|
||||
label="Text Encoders",
|
||||
type="diffusers_auto_models",
|
||||
display="input",
|
||||
required_block_params=["text_encoder"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
|
||||
@@ -263,7 +379,13 @@ class MellonParam:
|
||||
|
||||
Output from Controlnet node, input to Denoise node.
|
||||
"""
|
||||
return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display)
|
||||
return cls(
|
||||
name="controlnet_bundle",
|
||||
label="ControlNet",
|
||||
type="custom_controlnet",
|
||||
display=display,
|
||||
required_block_params="controlnet_image",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def ip_adapter(cls) -> "MellonParam":
|
||||
@@ -284,6 +406,86 @@ class MellonParam:
|
||||
return cls(name="doc", label="Doc", type="string", display="output")
|
||||
|
||||
|
||||
DEFAULT_NODE_SPECS = {
|
||||
"controlnet": None,
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
MellonParam.embeddings(display="input"),
|
||||
MellonParam.width(),
|
||||
MellonParam.height(),
|
||||
MellonParam.seed(),
|
||||
MellonParam.num_inference_steps(),
|
||||
MellonParam.guidance_scale(),
|
||||
MellonParam.strength(),
|
||||
MellonParam.image_latents_with_strength(),
|
||||
MellonParam.image_latents(),
|
||||
MellonParam.first_frame_latents(),
|
||||
MellonParam.controlnet_bundle(display="input"),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.unet(),
|
||||
MellonParam.guider(),
|
||||
MellonParam.scheduler(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.latents(display="output"),
|
||||
MellonParam.latents_preview(),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["embeddings"],
|
||||
"required_model_inputs": ["unet", "scheduler"],
|
||||
"block_name": "denoise",
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
MellonParam.image(),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.vae(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.image_latents(display="output"),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["image"],
|
||||
"required_model_inputs": ["vae"],
|
||||
"block_name": "vae_encoder",
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
MellonParam.prompt(),
|
||||
MellonParam.negative_prompt(),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.text_encoders(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.embeddings(display="output"),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["prompt"],
|
||||
"required_model_inputs": ["text_encoders"],
|
||||
"block_name": "text_encoder",
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
MellonParam.latents(display="input"),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.vae(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.images(),
|
||||
MellonParam.videos(),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["latents"],
|
||||
"required_model_inputs": ["vae"],
|
||||
"block_name": "decode",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def mark_required(label: str, marker: str = " *") -> str:
|
||||
"""Add required marker to label if not already present."""
|
||||
if label.endswith(marker):
|
||||
@@ -458,20 +660,39 @@ class MellonPipelineConfig:
|
||||
default_dtype: Default dtype (e.g., "float16", "bfloat16")
|
||||
"""
|
||||
# Convert all node specs to Mellon format immediately
|
||||
self.node_params = {}
|
||||
for node_type, spec in node_specs.items():
|
||||
if spec is None:
|
||||
self.node_params[node_type] = None
|
||||
else:
|
||||
self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type)
|
||||
self.node_specs = node_specs
|
||||
|
||||
self.label = label
|
||||
self.default_repo = default_repo
|
||||
self.default_dtype = default_dtype
|
||||
|
||||
@property
|
||||
def node_params(self) -> Dict[str, Any]:
|
||||
"""Lazily compute node_params from node_specs."""
|
||||
params = {}
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
params[node_type] = None
|
||||
else:
|
||||
params[node_type] = node_spec_to_mellon_dict(spec, node_type)
|
||||
return params
|
||||
|
||||
def __repr__(self) -> str:
|
||||
node_types = list(self.node_params.keys())
|
||||
return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})"
|
||||
lines = [
|
||||
f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"
|
||||
]
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
lines.append(f" {node_type}: None")
|
||||
else:
|
||||
inputs = [p.name for p in spec.get("inputs", [])]
|
||||
model_inputs = [p.name for p in spec.get("model_inputs", [])]
|
||||
outputs = [p.name for p in spec.get("outputs", [])]
|
||||
lines.append(f" {node_type}:")
|
||||
lines.append(f" inputs: {inputs}")
|
||||
lines.append(f" model_inputs: {model_inputs}")
|
||||
lines.append(f" outputs: {outputs}")
|
||||
return "\n".join(lines)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to a JSON-serializable dictionary."""
|
||||
@@ -622,3 +843,85 @@ class MellonPipelineConfig:
|
||||
return cls.from_json_file(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
@classmethod
|
||||
def from_blocks(
|
||||
cls,
|
||||
blocks,
|
||||
template: Dict[str, Optional[Dict[str, Any]]] = None,
|
||||
label: str = "",
|
||||
default_repo: str = "",
|
||||
default_dtype: str = "bfloat16",
|
||||
) -> "MellonPipelineConfig":
|
||||
"""
|
||||
Create MellonPipelineConfig by matching template against actual pipeline blocks.
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_NODE_SPECS
|
||||
|
||||
sub_block_map = dict(blocks.sub_blocks)
|
||||
|
||||
def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]:
|
||||
"""Filter template spec params based on what the block actually supports."""
|
||||
block_input_names = set(block.input_names)
|
||||
block_output_names = set(block.intermediate_output_names)
|
||||
block_component_names = set(block.component_names)
|
||||
|
||||
filtered_inputs = [
|
||||
p
|
||||
for p in template_spec.get("inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_input_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_model_inputs = [
|
||||
p
|
||||
for p in template_spec.get("model_inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_component_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_outputs = [
|
||||
p
|
||||
for p in template_spec.get("outputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_output_names for name in p.required_block_params)
|
||||
]
|
||||
|
||||
filtered_input_names = {p.name for p in filtered_inputs}
|
||||
filtered_model_input_names = {p.name for p in filtered_model_inputs}
|
||||
|
||||
filtered_required_inputs = [
|
||||
r for r in template_spec.get("required_inputs", []) if r in filtered_input_names
|
||||
]
|
||||
filtered_required_model_inputs = [
|
||||
r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names
|
||||
]
|
||||
|
||||
return {
|
||||
"inputs": filtered_inputs,
|
||||
"model_inputs": filtered_model_inputs,
|
||||
"outputs": filtered_outputs,
|
||||
"required_inputs": filtered_required_inputs,
|
||||
"required_model_inputs": filtered_required_model_inputs,
|
||||
"block_name": template_spec.get("block_name"),
|
||||
}
|
||||
|
||||
# Build node specs
|
||||
node_specs = {}
|
||||
for node_type, template_spec in template.items():
|
||||
if template_spec is None:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
block_name = template_spec.get("block_name")
|
||||
if block_name is None or block_name not in sub_block_map:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name])
|
||||
|
||||
return cls(
|
||||
node_specs=node_specs,
|
||||
label=label or getattr(blocks, "model_name", ""),
|
||||
default_repo=default_repo,
|
||||
default_dtype=default_dtype,
|
||||
)
|
||||
|
||||
@@ -155,7 +155,7 @@ else:
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"]
|
||||
_import_structure["cogvideo"] = [
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -598,7 +598,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
|
||||
from .chronoedit import ChronoEditPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
||||
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
|
||||
_import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_chroma import ChromaPipeline
|
||||
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
|
||||
from .pipeline_chroma_inpainting import ChromaInpaintPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
1197
src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py
Normal file
1197
src/diffusers/pipelines/chroma/pipeline_chroma_inpainting.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -278,6 +278,9 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = (
|
||||
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ class MultilingualCLIP(PreTrainedModel):
|
||||
self.LinearTransformation = torch.nn.Linear(
|
||||
in_features=config.transformerDimensions, out_features=config.numDims
|
||||
)
|
||||
if hasattr(self, "post_init"):
|
||||
self.post_init()
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]
|
||||
|
||||
@@ -782,6 +782,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
self.prefix_encoder = PrefixEncoder(config)
|
||||
self.dropout = torch.nn.Dropout(0.1)
|
||||
|
||||
if hasattr(self, "post_init"):
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embedding.word_embeddings
|
||||
|
||||
@@ -811,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", None)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
batch_size, seq_length = input_ids.shape
|
||||
|
||||
@@ -260,10 +260,10 @@ class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
|
||||
text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
|
||||
all_text.append(text)
|
||||
|
||||
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)
|
||||
inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device)
|
||||
|
||||
self.text_encoder.to(device)
|
||||
generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
|
||||
generated_ids.to(device)
|
||||
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
||||
output_text = self.text_processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
|
||||
@@ -632,6 +632,21 @@ class ChromaImg2ImgPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChromaInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class ChromaPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -20,7 +20,9 @@ class TestAutoModel(unittest.TestCase):
|
||||
side_effect=[EnvironmentError("File not found"), {"model_type": "clip_text_model"}],
|
||||
)
|
||||
def test_load_from_config_transformers_with_subfolder(self, mock_load_config):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
|
||||
)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
def test_load_from_config_without_subfolder(self):
|
||||
@@ -28,5 +30,7 @@ class TestAutoModel(unittest.TestCase):
|
||||
assert isinstance(model, LongformerModel)
|
||||
|
||||
def test_load_from_model_index(self):
|
||||
model = AutoModel.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder")
|
||||
model = AutoModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-stable-diffusion-torch", subfolder="text_encoder", use_safetensors=False
|
||||
)
|
||||
assert isinstance(model, CLIPTextModel)
|
||||
|
||||
@@ -108,7 +108,7 @@ class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
inputs = {
|
||||
"prompt": "dance monkey",
|
||||
"negative_prompt": "",
|
||||
"negative_prompt": "bad",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 6.0,
|
||||
|
||||
Reference in New Issue
Block a user