mirror of https://github.com/huggingface/diffusers.git
Compare commits: 1 commit
modular-mo... | sayakpaul- | 537d4de2cc

@@ -346,8 +346,6 @@
    title: Flux2Transformer2DModel
  - local: api/models/flux_transformer
    title: FluxTransformer2DModel
  - local: api/models/glm_image_transformer2d
    title: GlmImageTransformer2DModel
  - local: api/models/hidream_image_transformer
    title: HiDreamImageTransformer2DModel
  - local: api/models/hunyuan_transformer2d
@@ -542,8 +540,6 @@
    title: Flux2
  - local: api/pipelines/control_flux_inpaint
    title: FluxControlInpaint
  - local: api/pipelines/glm_image
    title: GLM-Image
  - local: api/pipelines/hidream
    title: HiDream-I1
  - local: api/pipelines/hunyuandit

@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# GlmImageTransformer2DModel

A Diffusion Transformer model for 2D data from [GLM-Image] (TODO).
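
A minimal loading sketch is shown below; it is an illustrative assumption rather than part of the original doc, and it assumes the `zai-org/GLM-Image` checkpoint follows the usual diffusers layout with the denoiser weights in a `transformer` subfolder.

```python
import torch

from diffusers import GlmImageTransformer2DModel

# Assumes the standard diffusers repository layout with a `transformer/` subfolder.
transformer = GlmImageTransformer2DModel.from_pretrained(
    "zai-org/GLM-Image", subfolder="transformer", torch_dtype=torch.bfloat16
)
```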

## GlmImageTransformer2DModel

[[autodoc]] GlmImageTransformer2DModel

@@ -99,9 +99,3 @@ image.save("chroma-single-file.png")
[[autodoc]] ChromaImg2ImgPipeline
  - all
  - __call__

## ChromaInpaintPipeline

[[autodoc]] ChromaInpaintPipeline
  - all
  - __call__

@@ -1,95 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# GLM-Image

## Overview

GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained detail. In general image generation quality it is on par with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.

Model architecture: a hybrid autoregressive + diffusion decoder design.

+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands it to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. The AR model is available as the `GlmImageForConditionalGeneration` class in the `transformers` library.
+ Diffusion decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.

Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.

+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.

GLM-Image supports both text-to-image and image-to-image generation within a single model:

+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).

## Usage examples

### Text to Image Generation

```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline

pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda")
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
image = pipe(
    prompt=prompt,
    height=32 * 32,
    width=36 * 32,
    num_inference_steps=30,
    guidance_scale=1.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]

image.save("output_t2i.png")
```

### Image to Image Generation

```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
from PIL import Image

pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16, device_map="cuda")
image_path = "cond.jpg"
prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
image = Image.open(image_path).convert("RGB")
image = pipe(
    prompt=prompt,
    image=[image],  # multiple images can be passed for multi-image-to-image generation, e.g. [image, image1]
    height=33 * 32,
    width=32 * 32,
    num_inference_steps=30,
    guidance_scale=1.5,
    generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]

image.save("output_i2i.png")
```

+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting `do_sample=False`, as this may lead to incorrect or degenerate outputs from the AR model.
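
Because outputs can differ noticeably between runs, one practical pattern is to keep sampling enabled and generate a few candidates with different seeds, then pick the best one. The snippet below is a minimal sketch of that idea, reusing `pipe` and `prompt` from the text-to-image example above; the loop and output file names are illustrative, not part of the original example.

```python
import torch

# Sample several candidates with different seeds; the AR prior is stochastic,
# so each run can produce a noticeably different composition.
for seed in (0, 1, 2):
    image = pipe(
        prompt=prompt,
        height=32 * 32,
        width=36 * 32,
        num_inference_steps=30,
        guidance_scale=1.5,
        generator=torch.Generator(device="cuda").manual_seed(seed),
    ).images[0]
    image.save(f"output_t2i_seed{seed}.png")
```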

## GlmImagePipeline

[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
  - all
  - __call__

## GlmImagePipelineOutput

[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput
@@ -23,7 +23,6 @@ from .utils import (
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -226,7 +225,6 @@ else:
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -460,7 +458,6 @@ else:
|
||||
"BriaFiboPipeline",
|
||||
"BriaPipeline",
|
||||
"ChromaImg2ImgPipeline",
|
||||
"ChromaInpaintPipeline",
|
||||
"ChromaPipeline",
|
||||
"ChronoEditPipeline",
|
||||
"CLIPImageProjection",
|
||||
@@ -495,7 +492,6 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -983,7 +979,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -1187,7 +1182,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
BriaFiboPipeline,
|
||||
BriaPipeline,
|
||||
ChromaImg2ImgPipeline,
|
||||
ChromaInpaintPipeline,
|
||||
ChromaPipeline,
|
||||
ChronoEditPipeline,
|
||||
CLIPImageProjection,
|
||||
@@ -1222,7 +1216,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
|
||||
@@ -98,7 +98,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -209,7 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
|
||||
@@ -1573,6 +1573,8 @@ def _templated_context_parallel_attention(
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("Attention mask is not yet supported for templated attention.")
|
||||
if is_causal:
|
||||
raise ValueError("Causal attention is not yet supported for templated attention.")
|
||||
if enable_gqa:
|
||||
|
||||
@@ -355,9 +355,8 @@ def _load_shard_file(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
@@ -403,7 +402,6 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
||||
@@ -430,7 +428,6 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
|
||||
|
||||
@@ -1306,7 +1306,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
dduf_entries=dduf_entries,
|
||||
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
@@ -1361,12 +1360,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
|
||||
# Checks if the model has been loaded in 4-bit or 8-bit with BNB
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
|
||||
"Calling `cuda()` is not supported for `8-bit` quantized models. "
|
||||
" Please use the model as it is, since the model has already been set to the correct devices."
|
||||
)
|
||||
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
|
||||
elif is_bitsandbytes_version("<", "0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
@@ -1413,16 +1412,17 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
)
|
||||
|
||||
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
|
||||
if getattr(self, "is_loaded_in_8bit", False) and is_bitsandbytes_version("<", "0.48.0"):
|
||||
if getattr(self, "is_loaded_in_8bit", False):
|
||||
raise ValueError(
|
||||
"Calling `to()` is not supported for `8-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.48.0."
|
||||
"`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
|
||||
" model has already been set to the correct devices and casted to the correct `dtype`."
|
||||
)
|
||||
elif getattr(self, "is_loaded_in_4bit", False) and is_bitsandbytes_version("<", "0.43.2"):
|
||||
elif is_bitsandbytes_version("<", "0.43.2"):
|
||||
raise ValueError(
|
||||
"Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
|
||||
f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
|
||||
)
|
||||
|
||||
if _is_group_offload_enabled(self) and device_arg_or_kwarg_present:
|
||||
logger.warning(
|
||||
f"The module '{self.__class__.__name__}' is group offloaded and moving it using `.to()` is not supported."
|
||||
@@ -1592,7 +1592,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||
is_parallel_loading_enabled: Optional[bool] = False,
|
||||
disable_mmap: bool = False,
|
||||
):
|
||||
model_state_dict = model.state_dict()
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
@@ -1661,7 +1660,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if is_parallel_loading_enabled:
|
||||
|
||||
@@ -27,7 +27,6 @@ if is_torch_available():
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
|
||||
@@ -1,621 +0,0 @@
|
||||
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
|
||||
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
hidden_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
|
||||
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
|
||||
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
|
||||
|
||||
# (B, 2 * condition_dim)
|
||||
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
|
||||
conditioning = timesteps_emb + condition_emb
|
||||
conditioning = F.silu(conditioning)
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class GlmImageImageProjector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
hidden_size: int = 2560,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = hidden_states.dtype
|
||||
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
|
||||
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
|
||||
|
||||
emb = self.linear(temb)
|
||||
(
|
||||
shift_msa,
|
||||
c_shift_msa,
|
||||
scale_msa,
|
||||
c_scale_msa,
|
||||
gate_msa,
|
||||
c_gate_msa,
|
||||
shift_mlp,
|
||||
c_shift_mlp,
|
||||
scale_mlp,
|
||||
c_scale_mlp,
|
||||
gate_mlp,
|
||||
c_gate_mlp,
|
||||
) = emb.chunk(12, dim=1)
|
||||
|
||||
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class GlmImageLayerKVCache:
|
||||
"""KV cache for GlmImage model."""
|
||||
|
||||
def __init__(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode: Optional[str] = None # "write", "read", "skip"
|
||||
|
||||
def store(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache is None:
|
||||
self.k_cache = k
|
||||
self.v_cache = v
|
||||
else:
|
||||
self.k_cache = torch.cat([self.k_cache, k], dim=1)
|
||||
self.v_cache = torch.cat([self.v_cache, v], dim=1)
|
||||
|
||||
def get(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache.shape[0] != k.shape[0]:
|
||||
k_cache_expanded = self.k_cache.expand(k.shape[0], -1, -1, -1)
|
||||
v_cache_expanded = self.v_cache.expand(v.shape[0], -1, -1, -1)
|
||||
else:
|
||||
k_cache_expanded = self.k_cache
|
||||
v_cache_expanded = self.v_cache
|
||||
|
||||
k_cache = torch.cat([k_cache_expanded, k], dim=1)
|
||||
v_cache = torch.cat([v_cache_expanded, v], dim=1)
|
||||
return k_cache, v_cache
|
||||
|
||||
def clear(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode = None
|
||||
|
||||
|
||||
class GlmImageKVCache:
|
||||
"""Container for all layers' KV caches."""
|
||||
|
||||
def __init__(self, num_layers: int):
|
||||
self.num_layers = num_layers
|
||||
self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
|
||||
|
||||
def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
|
||||
return self.caches[layer_idx]
|
||||
|
||||
def set_mode(self, mode: Optional[str]):
|
||||
if mode is not None and mode not in ["write", "read", "skip"]:
|
||||
raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
|
||||
for cache in self.caches:
|
||||
cache.mode = mode
|
||||
|
||||
def clear(self):
|
||||
for cache in self.caches:
|
||||
cache.clear()
|
||||
|
||||
|
||||
class GlmImageAttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
|
||||
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
|
||||
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = encoder_hidden_states.dtype
|
||||
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query).to(dtype=dtype)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key).to(dtype=dtype)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, text_seq_length:, :, :] = apply_rotary_emb(
|
||||
query[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, text_seq_length:, :, :] = apply_rotary_emb(
|
||||
key[:, text_seq_length:, :, :], image_rotary_emb, sequence_dim=1, use_real_unbind_dim=-2
|
||||
)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.mode == "write":
|
||||
kv_cache.store(key, value)
|
||||
elif kv_cache.mode == "read":
|
||||
key, value = kv_cache.get(key, value)
|
||||
elif kv_cache.mode == "skip":
|
||||
pass
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attn_mask = attention_mask
|
||||
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
|
||||
text_attn_mask = text_attn_mask.float().to(query.device)
|
||||
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
mix_attn_mask[:, :text_seq_length] = text_attn_mask
|
||||
mix_attn_mask = mix_attn_mask.unsqueeze(2)
|
||||
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
|
||||
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class GlmImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2560,
|
||||
num_attention_heads: int = 64,
|
||||
attention_head_dim: int = 40,
|
||||
time_embed_dim: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Attention
|
||||
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm="layer_norm",
|
||||
elementwise_affine=False,
|
||||
eps=1e-5,
|
||||
processor=GlmImageAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Feedforward
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
norm_encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
) = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# 2. Attention
|
||||
attention_kwargs = attention_kwargs or {}
|
||||
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
kv_cache=kv_cache,
|
||||
**attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
|
||||
# 3. Feedforward
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
|
||||
1 + c_scale_mlp.unsqueeze(1)
|
||||
) + c_shift_mlp.unsqueeze(1)
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output_context = self.ff(norm_encoder_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class GlmImageRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
dim_h, dim_w = self.dim // 2, self.dim // 2
|
||||
h_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
|
||||
)
|
||||
w_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
|
||||
)
|
||||
h_seq = torch.arange(height)
|
||||
w_seq = torch.arange(width)
|
||||
freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
# Create position matrices for height and width
|
||||
# [height, 1, dim//4] and [1, width, dim//4]
|
||||
freqs_h = freqs_h.unsqueeze(1)
|
||||
freqs_w = freqs_w.unsqueeze(0)
|
||||
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
|
||||
freqs_h = freqs_h.expand(height, width, -1)
|
||||
freqs_w = freqs_w.expand(height, width, -1)
|
||||
|
||||
# Concatenate along last dimension to get [height, width, dim//2]
|
||||
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
||||
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
|
||||
freqs = freqs.reshape(height * width, -1)
|
||||
return (freqs.cos(), freqs.sin())
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormContinuous(nn.Module):
|
||||
"""
|
||||
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
|
||||
Linear on conditioning embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
# *** NO SiLU here ***
|
||||
emb = self.linear(conditioning_embedding.to(x.dtype))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `40`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `64`):
|
||||
The number of heads to use for multi-head attention.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_embed_dim (`int`, defaults to `1472`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
condition_dim (`int`, defaults to `256`):
|
||||
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
|
||||
crop_coords).
|
||||
pos_embed_max_size (`int`, defaults to `128`):
|
||||
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
|
||||
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
|
||||
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
|
||||
patch_size => 128 * 8 * 2 => 2048`.
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The base resolution of input latents. If height/width is not provided during generation, this value is used
|
||||
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"GlmImageTransformerBlock",
|
||||
"GlmImageImageProjector",
|
||||
"GlmImageImageProjector",
|
||||
]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
|
||||
_skip_keys = ["kv_caches"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_layers: int = 30,
|
||||
attention_head_dim: int = 40,
|
||||
num_attention_heads: int = 64,
|
||||
text_embed_dim: int = 1472,
|
||||
time_embed_dim: int = 512,
|
||||
condition_dim: int = 256,
|
||||
prior_vq_quantizer_codebook_size: int = 16384,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
|
||||
# Each of these are sincos embeddings of shape 2 * condition_dim
|
||||
pooled_projection_dim = 2 * 2 * condition_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels
|
||||
|
||||
# 1. RoPE
|
||||
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
|
||||
|
||||
# 2. Patch & Text-timestep embedding
|
||||
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
|
||||
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
|
||||
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
|
||||
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
|
||||
|
||||
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=time_embed_dim,
|
||||
condition_dim=condition_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
timesteps_dim=time_embed_dim,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output projection
|
||||
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
prior_token_id: torch.Tensor,
|
||||
prior_token_drop: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
kv_caches: Optional[GlmImageKVCache] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch & Timestep embeddings
|
||||
p = self.config.patch_size
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
hidden_states = self.image_projector(hidden_states)
|
||||
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
|
||||
prior_embedding = self.prior_token_embedding(prior_token_id)
|
||||
prior_embedding[prior_token_drop] *= 0.0
|
||||
prior_hidden_states = self.prior_projector(prior_embedding)
|
||||
|
||||
hidden_states = hidden_states + prior_hidden_states
|
||||
|
||||
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
|
||||
|
||||
# 3. Transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_cache=kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
|
||||
# Rearrange tensor from (B, H_p, W_p, C, p, p) to (B, C, H_p * p, W_p * p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -761,14 +761,11 @@ class QwenImageTransformer2DModel(
|
||||
_no_split_modules = ["QwenImageTransformerBlock"]
|
||||
_skip_layerwise_casting_patterns = ["pos_embed", "norm"]
|
||||
_repeated_blocks = ["QwenImageTransformerBlock"]
|
||||
# Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
|
||||
_cp_plan = {
|
||||
"transformer_blocks.0": {
|
||||
"": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
"encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"transformer_blocks.*": {
|
||||
"modulate_index": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
"encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
|
||||
},
|
||||
"pos_embed": {
|
||||
0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
# Simple typed wrapper for parameter overrides
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from huggingface_hub.utils import (
|
||||
@@ -23,18 +23,10 @@ logger = logging.getLogger(__name__)
|
||||
@dataclass(frozen=True)
|
||||
class MellonParam:
|
||||
"""
|
||||
Parameter definition for Mellon nodes.
|
||||
Parameter definition for Mellon nodes.
|
||||
|
||||
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with
|
||||
MellonParam(name="...", label="...", type="...").
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Custom param
|
||||
MellonParam(name="my_param", label="My Param", type="float", default=0.5)
|
||||
# Output in Mellon node definition:
|
||||
# "my_param": {"label": "My Param", "type": "float", "default": 0.5}
|
||||
```
|
||||
Use factory methods for common params (e.g., MellonParam.seed()) or create custom ones with MellonParam(name="...",
|
||||
label="...", type="...").
|
||||
"""
|
||||
|
||||
name: str
|
||||
@@ -50,165 +42,55 @@ class MellonParam:
|
||||
fieldOptions: Optional[Dict[str, Any]] = None
|
||||
onChange: Any = None
|
||||
onSignal: Any = None
|
||||
required_block_params: Optional[Union[str, List[str]]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict for Mellon schema, excluding None values and name."""
|
||||
data = asdict(self)
|
||||
return {k: v for k, v in data.items() if v is not None and k not in ("name", "required_block_params")}
|
||||
return {k: v for k, v in data.items() if v is not None and k != "name"}
|
||||
|
||||
@classmethod
|
||||
def image(cls) -> "MellonParam":
|
||||
"""
|
||||
Image input parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"image": {"label": "Image", "type": "image", "display": "input"}
|
||||
"""
|
||||
return cls(name="image", label="Image", type="image", display="input", required_block_params=["image"])
|
||||
return cls(name="image", label="Image", type="image", display="input")
|
||||
|
||||
@classmethod
|
||||
def images(cls) -> "MellonParam":
|
||||
"""
|
||||
Images output parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"images": {"label": "Images", "type": "image", "display": "output"}
|
||||
"""
|
||||
return cls(name="images", label="Images", type="image", display="output", required_block_params=["images"])
|
||||
return cls(name="images", label="Images", type="image", display="output")
|
||||
|
||||
@classmethod
|
||||
def control_image(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Control image parameter for ControlNet.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"control_image": {"label": "Control Image", "type": "image", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="control_image",
|
||||
label="Control Image",
|
||||
type="image",
|
||||
display=display,
|
||||
required_block_params=["control_image"],
|
||||
)
|
||||
return cls(name="control_image", label="Control Image", type="image", display=display)
|
||||
|
||||
@classmethod
|
||||
def latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Latents parameter.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"latents": {"label": "Latents", "type": "latents", "display": "input"}
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"latents": {"label": "Latents", "type": "latents", "display": "output"}
|
||||
"""
|
||||
return cls(name="latents", label="Latents", type="latents", display=display, required_block_params=["latents"])
|
||||
return cls(name="latents", label="Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
Image latents parameter for img2img workflows.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"image_latents": {"label": "Image Latents", "type": "latents", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["image_latents"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
First frame latents for video generation.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"first_frame_latents": {"label": "First Frame Latents", "type": "latents", "display": "input"}
|
||||
"""
|
||||
return cls(
|
||||
name="first_frame_latents",
|
||||
label="First Frame Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["first_frame_latents"],
|
||||
)
|
||||
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
"""
|
||||
Image latents with strength-based onChange behavior. When connected, shows strength slider; when disconnected,
|
||||
shows height/width.
|
||||
|
||||
Mellon node definition:
|
||||
"image_latents": {
|
||||
"label": "Image Latents", "type": "latents", "display": "input", "onChange": {"false": ["height",
|
||||
"width"], "true": ["strength"]}
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
type="latents",
|
||||
display="input",
|
||||
onChange={"false": ["height", "width"], "true": ["strength"]},
|
||||
required_block_params=["image_latents", "strength"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latents_preview(cls) -> "MellonParam":
|
||||
"""
|
||||
Latents preview output for visualizing latents in the UI.
|
||||
|
||||
Mellon node definition:
|
||||
"latents_preview": {"label": "Latents Preview", "type": "latent", "display": "output"}
|
||||
`Latents Preview` is a special output parameter that is used to preview the latents in the UI.
|
||||
"""
|
||||
return cls(name="latents_preview", label="Latents Preview", type="latent", display="output")
|
||||
|
||||
@classmethod
|
||||
def embeddings(cls, display: str = "output") -> "MellonParam":
|
||||
"""
|
||||
Text embeddings parameter.
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "output"}
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"embeddings": {"label": "Text Embeddings", "type": "embeddings", "display": "input"}
|
||||
"""
|
||||
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
"""
|
||||
Image embeddings parameter for IP-Adapter workflows.
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"image_embeds": {"label": "Image Embeddings", "type": "image_embeds", "display": "output"}
|
||||
"""
|
||||
return cls(
|
||||
name="image_embeds",
|
||||
label="Image Embeddings",
|
||||
type="image_embeds",
|
||||
display=display,
|
||||
required_block_params=["image_embeds"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
"""
|
||||
ControlNet conditioning scale slider.
|
||||
|
||||
Mellon node definition (default=0.5):
|
||||
"controlnet_conditioning_scale": {
|
||||
"label": "Controlnet Conditioning Scale", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0,
|
||||
"step": 0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet_conditioning_scale",
|
||||
label="Controlnet Conditioning Scale",
|
||||
@@ -217,20 +99,10 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["controlnet_conditioning_scale"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def control_guidance_start(cls, default: float = 0.0) -> "MellonParam":
|
||||
"""
|
||||
Control guidance start timestep.
|
||||
|
||||
Mellon node definition (default=0.0):
|
||||
"control_guidance_start": {
|
||||
"label": "Control Guidance Start", "type": "float", "default": 0.0, "min": 0.0, "max": 1.0, "step":
|
||||
0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="control_guidance_start",
|
||||
label="Control Guidance Start",
|
||||
@@ -239,19 +111,10 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["control_guidance_start"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def control_guidance_end(cls, default: float = 1.0) -> "MellonParam":
|
||||
"""
|
||||
Control guidance end timestep.
|
||||
|
||||
Mellon node definition (default=1.0):
|
||||
"control_guidance_end": {
|
||||
"label": "Control Guidance End", "type": "float", "default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="control_guidance_end",
|
||||
label="Control Guidance End",
|
||||
@@ -260,73 +123,22 @@ class MellonParam:
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["control_guidance_end"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prompt(cls, default: str = "") -> "MellonParam":
|
||||
"""
|
||||
Text prompt input as textarea.
|
||||
|
||||
Mellon node definition (default=""):
|
||||
"prompt": {"label": "Prompt", "type": "string", "default": "", "display": "textarea"}
|
||||
"""
|
||||
return cls(
|
||||
name="prompt",
|
||||
label="Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["prompt"],
|
||||
)
|
||||
return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea")
|
||||
|
||||
@classmethod
|
||||
def negative_prompt(cls, default: str = "") -> "MellonParam":
|
||||
"""
|
||||
Negative prompt input as textarea.
|
||||
|
||||
Mellon node definition (default=""):
|
||||
"negative_prompt": {"label": "Negative Prompt", "type": "string", "default": "", "display": "textarea"}
|
||||
"""
|
||||
return cls(
|
||||
name="negative_prompt",
|
||||
label="Negative Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["negative_prompt"],
|
||||
)
|
||||
return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea")
|
||||
|
||||
@classmethod
|
||||
def strength(cls, default: float = 0.5) -> "MellonParam":
|
||||
"""
|
||||
Denoising strength for img2img.
|
||||
|
||||
Mellon node definition (default=0.5):
|
||||
"strength": {"label": "Strength", "type": "float", "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}
|
||||
"""
|
||||
return cls(
|
||||
name="strength",
|
||||
label="Strength",
|
||||
type="float",
|
||||
default=default,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["strength"],
|
||||
)
|
||||
return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01)
|
||||
|
||||
@classmethod
|
||||
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
|
||||
"""
|
||||
CFG guidance scale slider.
|
||||
|
||||
Mellon node definition (default=5.0):
|
||||
"guidance_scale": {
|
||||
"label": "Guidance Scale", "type": "float", "display": "slider", "default": 5.0, "min": 1.0, "max":
|
||||
30.0, "step": 0.1
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="guidance_scale",
|
||||
label="Guidance Scale",
|
||||
@@ -340,273 +152,103 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def height(cls, default: int = 1024) -> "MellonParam":
|
||||
"""
|
||||
Image height in pixels.
|
||||
|
||||
Mellon node definition (default=1024):
|
||||
"height": {"label": "Height", "type": "int", "default": 1024, "min": 64, "step": 8}
|
||||
"""
|
||||
return cls(
|
||||
name="height",
|
||||
label="Height",
|
||||
type="int",
|
||||
default=default,
|
||||
min=64,
|
||||
step=8,
|
||||
required_block_params=["height"],
|
||||
)
|
||||
return cls(name="height", label="Height", type="int", default=default, min=64, step=8)
|
||||
|
||||
@classmethod
|
||||
def width(cls, default: int = 1024) -> "MellonParam":
|
||||
"""
|
||||
Image width in pixels.
|
||||
|
||||
Mellon node definition (default=1024):
|
||||
"width": {"label": "Width", "type": "int", "default": 1024, "min": 64, "step": 8}
|
||||
"""
|
||||
return cls(
|
||||
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
|
||||
)
|
||||
return cls(name="width", label="Width", type="int", default=default, min=64, step=8)
|
||||
|
||||
@classmethod
|
||||
def seed(cls, default: int = 0) -> "MellonParam":
|
||||
"""
|
||||
Random seed with randomize button.
|
||||
|
||||
Mellon node definition (default=0):
|
||||
"seed": {
|
||||
"label": "Seed", "type": "int", "default": 0, "min": 0, "max": 4294967295, "display": "random"
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="seed",
|
||||
label="Seed",
|
||||
type="int",
|
||||
default=default,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
display="random",
|
||||
required_block_params=["generator"],
|
||||
)
|
||||
return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random")
|
||||
|
||||
@classmethod
|
||||
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
|
||||
"""
|
||||
Number of denoising steps slider.
|
||||
|
||||
Mellon node definition (default=25):
|
||||
"num_inference_steps": {
|
||||
"label": "Steps", "type": "int", "default": 25, "min": 1, "max": 100, "display": "slider"
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="num_inference_steps",
|
||||
label="Steps",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=100,
|
||||
display="slider",
|
||||
required_block_params=["num_inference_steps"],
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
"""
|
||||
Number of video frames slider.
|
||||
|
||||
Mellon node definition (default=81):
|
||||
"num_frames": {"label": "Frames", "type": "int", "default": 81, "min": 1, "max": 480, "display": "slider"}
|
||||
"""
|
||||
return cls(
|
||||
name="num_frames",
|
||||
label="Frames",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=480,
|
||||
display="slider",
|
||||
required_block_params=["num_frames"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
"""
|
||||
Number of layers slider (for layered diffusion).
|
||||
|
||||
Mellon node definition (default=4):
|
||||
"layers": {"label": "Layers", "type": "int", "default": 4, "min": 1, "max": 10, "display": "slider"}
|
||||
"""
|
||||
return cls(
|
||||
name="layers",
|
||||
label="Layers",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=10,
|
||||
display="slider",
|
||||
required_block_params=["layers"],
|
||||
)
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
"""
|
||||
Video output parameter.
|
||||
|
||||
Mellon node definition:
|
||||
"videos": {"label": "Videos", "type": "video", "display": "output"}
|
||||
"""
|
||||
return cls(name="videos", label="Videos", type="video", display="output", required_block_params=["videos"])
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
"""
|
||||
VAE model input.
|
||||
VAE model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"vae": {"label": "VAE", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
"""
|
||||
Image encoder model input.
|
||||
|
||||
Mellon node definition:
|
||||
"image_encoder": {"label": "Image Encoder", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="image_encoder",
|
||||
label="Image Encoder",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["image_encoder"],
|
||||
)
|
||||
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
"""
|
||||
Denoising model (UNet/Transformer) input.
|
||||
Denoising model (UNet/Transformer) info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"unet": {"label": "Denoise Model", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="unet", label="Denoise Model", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def scheduler(cls) -> "MellonParam":
|
||||
"""
|
||||
Scheduler model input.
|
||||
Scheduler model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"scheduler": {"label": "Scheduler", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id'. Use
|
||||
components.get_one(model_id) to retrieve the actual scheduler.
|
||||
Contains keys like 'model_id', 'repo_id' etc. Use components.get_one(model_id) to retrieve the actual
|
||||
scheduler.
|
||||
"""
|
||||
return cls(name="scheduler", label="Scheduler", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def controlnet(cls) -> "MellonParam":
|
||||
"""
|
||||
ControlNet model input.
|
||||
ControlNet model info dict.
|
||||
|
||||
Mellon node definition:
|
||||
"controlnet": {"label": "ControlNet Model", "type": "diffusers_auto_model", "display": "input"}
|
||||
|
||||
Note: The value received is a model info dict with keys like 'model_id', 'repo_id', 'execution_device'. Use
|
||||
components.get_one(model_id) to retrieve the actual model.
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet",
|
||||
label="ControlNet Model",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["controlnet"],
|
||||
)
|
||||
return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def text_encoders(cls) -> "MellonParam":
|
||||
"""
|
||||
Text encoders dict input (multiple encoders).
|
||||
Dict of text encoder model info dicts.
|
||||
|
||||
Mellon node definition:
|
||||
"text_encoders": {"label": "Text Encoders", "type": "diffusers_auto_models", "display": "input"}
|
||||
|
||||
Note: The value received is a dict of model info dicts:
|
||||
{
|
||||
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
|
||||
'repo_id': '...'
|
||||
}
|
||||
Use components.get_one(model_id) to retrieve each model.
|
||||
Structure: {
|
||||
'text_encoder': {'model_id': ..., 'execution_device': ..., ...}, 'tokenizer': {'model_id': ..., ...},
|
||||
'repo_id': '...'
|
||||
} Use components.get_one(model_id) to retrieve each model.
|
||||
"""
|
||||
return cls(
|
||||
name="text_encoders",
|
||||
label="Text Encoders",
|
||||
type="diffusers_auto_models",
|
||||
display="input",
|
||||
required_block_params=["text_encoder"],
|
||||
)
|
||||
return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input")
|
||||
|
||||
@classmethod
|
||||
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
|
||||
"""
|
||||
ControlNet bundle containing model and processed control inputs. Output from ControlNet node, input to Denoise
|
||||
node.
|
||||
ControlNet bundle containing model info and processed control inputs.
|
||||
|
||||
Mellon node definition (display="input"):
|
||||
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "input"}
|
||||
Structure: {
    'controlnet': {'model_id': ..., ...},        # controlnet model info dict
    'control_image': ...,                        # processed control image/embeddings
    'controlnet_conditioning_scale': ..., ...    # other inputs expected by denoise blocks
}
|
||||
|
||||
Mellon node definition (display="output"):
|
||||
"controlnet_bundle": {"label": "ControlNet", "type": "custom_controlnet", "display": "output"}
|
||||
|
||||
Note: The value is a dict containing:
{
    'controlnet': {'model_id': ..., ...},    # controlnet model info
    'control_image': ...,                    # processed control image/embeddings
    'controlnet_conditioning_scale': ...,    # and other denoise block inputs
}
|
||||
Output from Controlnet node, input to Denoise node.
|
||||
"""
|
||||
return cls(
|
||||
name="controlnet_bundle",
|
||||
label="ControlNet",
|
||||
type="custom_controlnet",
|
||||
display=display,
|
||||
required_block_params="controlnet_image",
|
||||
)
|
||||
return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display)
|
||||
|
||||
@classmethod
|
||||
def ip_adapter(cls) -> "MellonParam":
|
||||
"""
|
||||
IP-Adapter input.
|
||||
|
||||
Mellon node definition:
|
||||
"ip_adapter": {"label": "IP Adapter", "type": "custom_ip_adapter", "display": "input"}
|
||||
"""
|
||||
return cls(name="ip_adapter", label="IP Adapter", type="custom_ip_adapter", display="input")
|
||||
|
||||
@classmethod
|
||||
def guider(cls) -> "MellonParam":
|
||||
"""
|
||||
Custom guider input. When connected, hides the guidance_scale slider.
|
||||
|
||||
Mellon node definition:
|
||||
"guider": {
|
||||
"label": "Guider", "type": "custom_guider", "display": "input", "onChange": {false: ["guidance_scale"],
|
||||
true: []}
|
||||
}
|
||||
"""
|
||||
return cls(
|
||||
name="guider",
|
||||
label="Guider",
|
||||
@@ -617,96 +259,9 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def doc(cls) -> "MellonParam":
|
||||
"""
|
||||
Documentation output for inspecting the underlying modular pipeline.
|
||||
|
||||
Mellon node definition:
|
||||
"doc": {"label": "Doc", "type": "string", "display": "output"}
|
||||
"""
|
||||
return cls(name="doc", label="Doc", type="string", display="output")
|
||||
|
||||
|
||||
DEFAULT_NODE_SPECS = {
|
||||
"controlnet": None,
|
||||
"denoise": {
|
||||
"inputs": [
|
||||
MellonParam.embeddings(display="input"),
|
||||
MellonParam.width(),
|
||||
MellonParam.height(),
|
||||
MellonParam.seed(),
|
||||
MellonParam.num_inference_steps(),
|
||||
MellonParam.num_frames(),
|
||||
MellonParam.guidance_scale(),
|
||||
MellonParam.strength(),
|
||||
MellonParam.image_latents_with_strength(),
|
||||
MellonParam.image_latents(),
|
||||
MellonParam.first_frame_latents(),
|
||||
MellonParam.controlnet_bundle(display="input"),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.unet(),
|
||||
MellonParam.guider(),
|
||||
MellonParam.scheduler(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.latents(display="output"),
|
||||
MellonParam.latents_preview(),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["embeddings"],
|
||||
"required_model_inputs": ["unet", "scheduler"],
|
||||
"block_name": "denoise",
|
||||
},
|
||||
"vae_encoder": {
|
||||
"inputs": [
|
||||
MellonParam.image(),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.vae(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.image_latents(display="output"),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["image"],
|
||||
"required_model_inputs": ["vae"],
|
||||
"block_name": "vae_encoder",
|
||||
},
|
||||
"text_encoder": {
|
||||
"inputs": [
|
||||
MellonParam.prompt(),
|
||||
MellonParam.negative_prompt(),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.text_encoders(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.embeddings(display="output"),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["prompt"],
|
||||
"required_model_inputs": ["text_encoders"],
|
||||
"block_name": "text_encoder",
|
||||
},
|
||||
"decoder": {
|
||||
"inputs": [
|
||||
MellonParam.latents(display="input"),
|
||||
],
|
||||
"model_inputs": [
|
||||
MellonParam.vae(),
|
||||
],
|
||||
"outputs": [
|
||||
MellonParam.images(),
|
||||
MellonParam.videos(),
|
||||
MellonParam.doc(),
|
||||
],
|
||||
"required_inputs": ["latents"],
|
||||
"required_model_inputs": ["vae"],
|
||||
"block_name": "decode",
|
||||
},
|
||||
}
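# Illustrative only: MellonPipelineConfig converts each non-None spec above to a
# Mellon-format dict via node_spec_to_mellon_dict(spec, node_type), e.g.
#   mellon_denoise = node_spec_to_mellon_dict(DEFAULT_NODE_SPECS["denoise"], "denoise")
# A None entry (such as "controlnet" here) leaves that node type undefined by default.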
|
||||
|
||||
|
||||
def mark_required(label: str, marker: str = " *") -> str:
|
||||
"""Add required marker to label if not already present."""
|
||||
if label.endswith(marker):
|
||||
@@ -881,42 +436,20 @@ class MellonPipelineConfig:
|
||||
default_dtype: Default dtype (e.g., "float16", "bfloat16")
|
||||
"""
|
||||
# Convert all node specs to Mellon format immediately
|
||||
self.node_specs = node_specs
|
||||
self.node_params = {}
|
||||
for node_type, spec in node_specs.items():
|
||||
if spec is None:
|
||||
self.node_params[node_type] = None
|
||||
else:
|
||||
self.node_params[node_type] = node_spec_to_mellon_dict(spec, node_type)
|
||||
|
||||
self.label = label
|
||||
self.default_repo = default_repo
|
||||
self.default_dtype = default_dtype
|
||||
|
||||
@property
|
||||
def node_params(self) -> Dict[str, Any]:
|
||||
"""Lazily compute node_params from node_specs."""
|
||||
if self.node_specs is None:
|
||||
return self._node_params
|
||||
|
||||
params = {}
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
params[node_type] = None
|
||||
else:
|
||||
params[node_type] = node_spec_to_mellon_dict(spec, node_type)
|
||||
return params
|
||||
|
||||
def __repr__(self) -> str:
|
||||
lines = [
|
||||
f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"
|
||||
]
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
lines.append(f" {node_type}: None")
|
||||
else:
|
||||
inputs = [p.name for p in spec.get("inputs", [])]
|
||||
model_inputs = [p.name for p in spec.get("model_inputs", [])]
|
||||
outputs = [p.name for p in spec.get("outputs", [])]
|
||||
lines.append(f" {node_type}:")
|
||||
lines.append(f" inputs: {inputs}")
|
||||
lines.append(f" model_inputs: {model_inputs}")
|
||||
lines.append(f" outputs: {outputs}")
|
||||
return "\n".join(lines)
|
||||
node_types = list(self.node_params.keys())
|
||||
return f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r}, node_params={node_types})"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to a JSON-serializable dictionary."""
|
||||
@@ -935,8 +468,7 @@ class MellonPipelineConfig:
|
||||
Note: The mellon_params are already in Mellon format when loading from JSON.
|
||||
"""
|
||||
instance = cls.__new__(cls)
|
||||
instance.node_specs = None
|
||||
instance._node_params = data.get("node_params", {})
|
||||
instance.node_params = data.get("node_params", {})
|
||||
instance.label = data.get("label", "")
|
||||
instance.default_repo = data.get("default_repo", "")
|
||||
instance.default_dtype = data.get("default_dtype", "")
|
||||
@@ -1068,85 +600,3 @@ class MellonPipelineConfig:
|
||||
return cls.from_json_file(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"The config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
@classmethod
|
||||
def from_blocks(
|
||||
cls,
|
||||
blocks,
|
||||
template: Dict[str, Optional[Dict[str, Any]]] = None,
|
||||
label: str = "",
|
||||
default_repo: str = "",
|
||||
default_dtype: str = "bfloat16",
|
||||
) -> "MellonPipelineConfig":
|
||||
"""
|
||||
Create MellonPipelineConfig by matching template against actual pipeline blocks.
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_NODE_SPECS
|
||||
|
||||
sub_block_map = dict(blocks.sub_blocks)
|
||||
|
||||
def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]:
|
||||
"""Filter template spec params based on what the block actually supports."""
|
||||
block_input_names = set(block.input_names)
|
||||
block_output_names = set(block.intermediate_output_names)
|
||||
block_component_names = set(block.component_names)
|
||||
|
||||
filtered_inputs = [
|
||||
p
|
||||
for p in template_spec.get("inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_input_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_model_inputs = [
|
||||
p
|
||||
for p in template_spec.get("model_inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_component_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_outputs = [
|
||||
p
|
||||
for p in template_spec.get("outputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_output_names for name in p.required_block_params)
|
||||
]
|
||||
|
||||
filtered_input_names = {p.name for p in filtered_inputs}
|
||||
filtered_model_input_names = {p.name for p in filtered_model_inputs}
|
||||
|
||||
filtered_required_inputs = [
|
||||
r for r in template_spec.get("required_inputs", []) if r in filtered_input_names
|
||||
]
|
||||
filtered_required_model_inputs = [
|
||||
r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names
|
||||
]
|
||||
|
||||
return {
|
||||
"inputs": filtered_inputs,
|
||||
"model_inputs": filtered_model_inputs,
|
||||
"outputs": filtered_outputs,
|
||||
"required_inputs": filtered_required_inputs,
|
||||
"required_model_inputs": filtered_required_model_inputs,
|
||||
"block_name": template_spec.get("block_name"),
|
||||
}
|
||||
|
||||
# Build node specs
|
||||
node_specs = {}
|
||||
for node_type, template_spec in template.items():
|
||||
if template_spec is None:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
block_name = template_spec.get("block_name")
|
||||
if block_name is None or block_name not in sub_block_map:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name])
|
||||
|
||||
return cls(
|
||||
node_specs=node_specs,
|
||||
label=label or getattr(blocks, "model_name", ""),
|
||||
default_repo=default_repo,
|
||||
default_dtype=default_dtype,
|
||||
)
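# Hypothetical usage of from_blocks; the block class name is taken from elsewhere
# in this diff and the exact constructor behaviour is an assumption:
#   blocks = WanAutoBlocks()
#   config = MellonPipelineConfig.from_blocks(blocks, label="wan", default_dtype="bfloat16")
#   config.node_params["denoise"]   # Mellon-format dict, or None if no block matched
# Template entries whose block_name is missing from blocks.sub_blocks resolve to None.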
|
||||
|
||||
@@ -84,7 +84,7 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_encoder"]
|
||||
block_names = ["image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -142,7 +142,7 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -203,7 +203,7 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_encoder",
|
||||
"vae_image_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -384,7 +384,7 @@ IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -401,7 +401,7 @@ FLF2V_BLOCKS = InsertableDict(
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -416,7 +416,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
@@ -438,7 +438,7 @@ TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
@@ -450,7 +450,7 @@ IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
|
||||
@@ -15,7 +15,6 @@ from ..utils import (
|
||||
is_torch_available,
|
||||
is_torch_npu_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -155,7 +154,7 @@ else:
|
||||
"AudioLDM2UNet2DConditionModel",
|
||||
]
|
||||
_import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline", "ChromaInpaintPipeline"]
|
||||
_import_structure["chroma"] = ["ChromaPipeline", "ChromaImg2ImgPipeline"]
|
||||
_import_structure["cogvideo"] = [
|
||||
"CogVideoXPipeline",
|
||||
"CogVideoXImageToVideoPipeline",
|
||||
@@ -435,8 +434,6 @@ else:
|
||||
"QwenImageLayeredPipeline",
|
||||
]
|
||||
_import_structure["chronoedit"] = ["ChronoEditPipeline"]
|
||||
_import_structure["glm_image"] = ["GlmImagePipeline"]
|
||||
|
||||
try:
|
||||
if not is_onnx_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -598,7 +595,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .blip_diffusion import BlipDiffusionPipeline
|
||||
from .bria import BriaPipeline
|
||||
from .bria_fibo import BriaFiboPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaInpaintPipeline, ChromaPipeline
|
||||
from .chroma import ChromaImg2ImgPipeline, ChromaPipeline
|
||||
from .chronoedit import ChronoEditPipeline
|
||||
from .cogvideo import (
|
||||
CogVideoXFunControlPipeline,
|
||||
@@ -679,7 +676,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
|
||||
@@ -52,7 +52,6 @@ from .flux import (
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -169,7 +168,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("chroma", ChromaPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("glm_image", GlmImagePipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
|
||||
@@ -24,7 +24,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_chroma"] = ["ChromaPipeline"]
|
||||
_import_structure["pipeline_chroma_img2img"] = ["ChromaImg2ImgPipeline"]
|
||||
_import_structure["pipeline_chroma_inpainting"] = ["ChromaInpaintPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -34,7 +33,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_chroma import ChromaPipeline
|
||||
from .pipeline_chroma_img2img import ChromaImg2ImgPipeline
|
||||
from .pipeline_chroma_inpainting import ChromaInpaintPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
File diff suppressed because it is too large
@@ -1,59 +0,0 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
|
||||
|
||||
# Import transformers components so they can be resolved during pipeline loading
|
||||
|
||||
if is_transformers_available() and is_transformers_version(">=", "4.57.4"):
|
||||
try:
|
||||
from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
|
||||
|
||||
_additional_imports["GlmImageForConditionalGeneration"] = GlmImageForConditionalGeneration
|
||||
_additional_imports["GlmImageProcessor"] = GlmImageProcessor
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_glm_image import GlmImagePipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,825 +0,0 @@
|
||||
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, PreTrainedModel, ProcessorMixin, T5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...models import AutoencoderKL, GlmImageTransformer2DModel
|
||||
from ...models.transformers.transformer_glm_image import GlmImageKVCache
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, is_transformers_version, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import GlmImagePipelineOutput
|
||||
|
||||
|
||||
# GlmImageProcessor and GlmImageForConditionalGeneration are not available in a stable transformers release as of 13/01/2026, so these base classes act as proxies.
|
||||
GlmImageProcessor = ProcessorMixin
|
||||
GlmImageForConditionalGeneration = PreTrainedModel
|
||||
if is_transformers_version(">=", "5.0.0.dev0"):
|
||||
from transformers import GlmImageForConditionalGeneration, GlmImageProcessor
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import GlmImagePipeline
|
||||
|
||||
>>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A photo of an astronaut riding a horse on mars"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
base_shift: float = 0.25,
|
||||
max_shift: float = 0.75,
|
||||
) -> float:
|
||||
m = (image_seq_len / base_seq_len) ** 0.5
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
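# Worked example for calculate_shift, assuming vae_scale_factor=8 and a DiT
# patch_size of 2 (assumptions, not guaranteed by this diff): a 1024x1024 image gives
#   image_seq_len = (1024 // 8) * (1024 // 8) // 2**2 = 4096
#   m  = (4096 / 256) ** 0.5 = 4.0
#   mu = 4.0 * 0.75 + 0.25 = 3.25
# so larger images shift the flow-matching sigma schedule upward.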
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
if not accepts_timesteps and not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is not None and sigmas is None:
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is None and sigmas is not None:
|
||||
if not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
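# Minimal usage sketch for retrieve_timesteps (scheduler construction and the `mu`
# value are assumptions for illustration):
#   scheduler = FlowMatchEulerDiscreteScheduler()
#   timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=50, device="cuda", mu=3.25)
# Passing `timesteps` or `sigmas` instead overrides the scheduler's own spacing.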
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class GlmImagePipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using GLM-Image.
|
||||
|
||||
This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
|
||||
transformer) model for image decoding.
|
||||
|
||||
Args:
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
Tokenizer for the text encoder.
|
||||
processor (`AutoProcessor`):
|
||||
Processor for the AR model to handle chat templates and tokenization.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder for glyph embeddings.
|
||||
vision_language_encoder ([`GlmImageForConditionalGeneration`]):
|
||||
The AR model that generates image tokens from text prompts.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
transformer ([`GlmImageTransformer2DModel`]):
|
||||
A text conditioned transformer to denoise the encoded image latents (DiT).
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "vision_language_encoder->text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: ByT5Tokenizer,
|
||||
processor: GlmImageProcessor,
|
||||
text_encoder: T5EncoderModel,
|
||||
vision_language_encoder: GlmImageForConditionalGeneration,
|
||||
vae: AutoencoderKL,
|
||||
transformer: GlmImageTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
text_encoder=text_encoder,
|
||||
vision_language_encoder=vision_language_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer")
|
||||
and self.transformer is not None
|
||||
and hasattr(self.transformer.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _compute_generation_params(
|
||||
image_grid_thw,
|
||||
is_text_to_image: bool,
|
||||
):
|
||||
grid_sizes = []
|
||||
grid_hw = []
|
||||
|
||||
for i in range(image_grid_thw.shape[0]):
|
||||
t, h, w = image_grid_thw[i].tolist()
|
||||
grid_sizes.append(int(h * w))
|
||||
grid_hw.append((int(h), int(w)))
|
||||
|
||||
if not is_text_to_image:
|
||||
max_new_tokens = grid_sizes[-1] + 1
|
||||
large_image_start_offset = 0
|
||||
target_grid_h, target_grid_w = grid_hw[-1]
|
||||
else:
|
||||
total_tokens = sum(grid_sizes)
|
||||
max_new_tokens = total_tokens + 1
|
||||
large_image_start_offset = sum(grid_sizes[1:])
|
||||
target_grid_h, target_grid_w = grid_hw[0]
|
||||
return max_new_tokens, large_image_start_offset, target_grid_h, target_grid_w
|
||||
|
||||
@staticmethod
|
||||
def _extract_large_image_tokens(
|
||||
outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
|
||||
) -> torch.Tensor:
|
||||
generated_tokens = outputs[0][input_length:]
|
||||
large_image_start = large_image_start_offset
|
||||
large_image_end = large_image_start + large_image_tokens
|
||||
return generated_tokens[large_image_start:large_image_end]
|
||||
|
||||
@staticmethod
|
||||
def _upsample_token_ids(token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
|
||||
token_ids = token_ids.view(1, 1, token_h, token_w)
|
||||
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
token_ids = token_ids.view(1, -1)
|
||||
return token_ids
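# Worked example for the nearest-neighbour upsampling above (illustrative):
#   ids = torch.arange(6)                                  # a 2x3 grid of d32 tokens
#   GlmImagePipeline._upsample_token_ids(ids, 2, 3)        # -> long tensor of shape (1, 24)
# Every token is repeated in a 2x2 block before being fed to the DiT as `prior_token_id`.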
|
||||
|
||||
def generate_prior_tokens(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
image: Optional[List[PIL.Image.Image]] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
is_text_to_image = image is None or len(image) == 0
|
||||
content = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
target_h=height,
|
||||
target_w=width,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
|
||||
image_grid_thw = inputs.get("image_grid_thw")
|
||||
max_new_tokens, large_image_offset, token_h, token_w = self._compute_generation_params(
|
||||
image_grid_thw=image_grid_thw, is_text_to_image=is_text_to_image
|
||||
)
|
||||
|
||||
prior_token_image_ids = None
|
||||
if image is not None:
|
||||
prior_token_image_embed = self.vision_language_encoder.get_image_features(
|
||||
inputs["pixel_values"], image_grid_thw[:-1]
|
||||
)
|
||||
prior_token_image_embed = torch.cat(prior_token_image_embed, dim=0)
|
||||
prior_token_image_ids = self.vision_language_encoder.get_image_tokens(
|
||||
prior_token_image_embed, image_grid_thw[:-1]
|
||||
)
|
||||
|
||||
# For GLM-Image, greedy decoding is not allowed; it may cause repetitive outputs.
|
||||
# max_new_tokens must be exactly grid_h * grid_w + 1 (the +1 is for EOS).
|
||||
outputs = self.vision_language_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs, inputs["input_ids"].shape[-1], large_image_offset, token_h * token_w
|
||||
)
|
||||
prior_token_ids = self._upsample_token_ids(prior_token_ids_d32, token_h, token_w)
|
||||
|
||||
return prior_token_ids, prior_token_image_ids
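# Illustrative call for the method above (shapes are assumptions, not guarantees):
#   prior_token_ids, prior_image_ids = pipe.generate_prior_tokens("A red bicycle", height=1024, width=1024)
#   prior_token_ids   # (1, (2*token_h) * (2*token_w)) discrete image tokens for the DiT
#   prior_image_ids   # None for text-to-image; condition-image tokens otherwise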
|
||||
|
||||
def get_glyph_texts(self, prompt):
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", prompt)
|
||||
+ re.findall(r"“([^“”]*)”", prompt)
|
||||
+ re.findall(r'"([^"]*)"', prompt)
|
||||
+ re.findall(r"「([^「」]*)」", prompt)
|
||||
)
|
||||
return ocr_texts
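# Example of the glyph extraction above: for
#   pipe.get_glyph_texts('Poster that says "GRAND OPENING" with the sign 「春」')
# the returned list is ['GRAND OPENING', '春']; only these quoted snippets are
# passed to the ByT5/T5 glyph text encoder, not the full prompt.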
|
||||
|
||||
def _get_glyph_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
glyph_texts = self.get_glyph_texts(prompt)
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts if len(glyph_texts) > 0 else [""],
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
|
||||
return glyph_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images that should be generated per prompt.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
max_sequence_length (`int`, defaults to `2048`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# For GLM-Image, negative_prompt must be "" instead of None
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prior_token_ids=None,
|
||||
prior_image_token_ids=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * self.transformer.config.patch_size * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.transformer.config.patch_size * 2) != 0
|
||||
):
|
||||
# GLM-Image uses 32× downsampling, so the image dimensions must be multiples of 32.
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 4} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
if prompt is not None and prior_token_ids is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prior_token_ids`: {prior_token_ids}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prior_token_ids is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prior_token_ids`. Cannot leave both `prompt` and `prior_token_ids` undefined."
|
||||
)
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if (prior_token_ids is None and prior_image_token_ids is not None) or (
|
||||
prior_token_ids is not None and prior_image_token_ids is None
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot forward only one `prior_token_ids`: {prior_token_ids} or `prior_image_token_ids`:"
|
||||
f" {prior_image_token_ids} provided. Please make sure both are provided or neither."
|
||||
)
|
||||
|
||||
if prior_token_ids is not None and prompt_embeds is None:
|
||||
raise ValueError("`prompt_embeds` must also be provided with `prior_token_ids`.")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
image: Optional[
|
||||
Union[
|
||||
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
|
||||
]
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 1.5,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prior_token_ids: Optional[torch.FloatTensor] = None,
|
||||
prior_image_token_ids: Optional[torch.Tensor] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 2048,
|
||||
) -> Union[GlmImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
|
||||
W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
|
||||
generates a 1152x768 image.
|
||||
image: Optional condition images for image-to-image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels. If not provided, derived from prompt shape info.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels. If not provided, derived from prompt shape info.
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps for DiT.
|
||||
guidance_scale (`float`, *optional*, defaults to `1.5`):
|
||||
Guidance scale for classifier-free guidance.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random generator for reproducibility.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: "pil", "np", or "latent".
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`GlmImagePipelineOutput`] or `tuple`: Generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prior_token_ids,
|
||||
prior_image_token_ids,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 due to AR model limitations, got {batch_size}")
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 2. Preprocess image tokens and prompt tokens
|
||||
if prior_token_ids is None:
|
||||
prior_token_ids, prior_token_image_ids = self.generate_prior_tokens(
|
||||
prompt=prompt[0] if isinstance(prompt, list) else prompt,
|
||||
image=image,
|
||||
height=height,
|
||||
width=width,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# 3. Preprocess image
|
||||
if image is not None:
|
||||
preprocessed_condition_images = []
|
||||
for img in image:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
preprocessed_condition_images.append(img)
|
||||
height = height or image_height
|
||||
width = width or image_width
|
||||
image = preprocessed_condition_images
|
||||
|
||||
# 4. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# 5. Prepare latents and (optional) image KV cache
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=latent_channels,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=prompt_embeds.dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
|
||||
|
||||
if image is not None:
|
||||
kv_caches.set_mode("write")
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.latent_channels, 1, 1)
|
||||
|
||||
latents_mean = latents_mean.to(device=device, dtype=prompt_embeds.dtype)
|
||||
latents_std = latents_std.to(device=device, dtype=prompt_embeds.dtype)
|
||||
|
||||
for condition_image, condition_image_prior_token_id in zip(image, prior_token_image_ids):
|
||||
condition_image = condition_image.to(device=device, dtype=prompt_embeds.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
|
||||
# Do not remove: this forward pass runs the reference image through the
# transformer at timestep 0 so its KV cache can be reused during denoising.
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=torch.zeros_like(prompt_embeds)[:1, :0, ...],
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
# 6. Prepare additional timestep conditions
|
||||
target_size = (height, width)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
|
||||
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# Prepare timesteps
|
||||
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
|
||||
self.transformer.config.patch_size**2
|
||||
)
|
||||
timesteps = (
|
||||
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
|
||||
if timesteps is None
|
||||
else np.array(timesteps)
|
||||
)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("base_shift", 0.25),
|
||||
self.scheduler.config.get("max_shift", 0.75),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
prior_token_drop_cond = torch.full_like(prior_token_ids, False, dtype=torch.bool)
|
||||
prior_token_drop_uncond = torch.full_like(prior_token_ids, True, dtype=torch.bool)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
|
||||
timestep = t.expand(latents.shape[0]) - 1
|
||||
|
||||
if image is not None:
|
||||
kv_caches.set_mode("read")
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
prior_token_id=prior_token_ids,
|
||||
prior_token_drop=prior_token_drop_cond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("skip")
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
prior_token_id=prior_token_ids,
|
||||
prior_token_drop=prior_token_drop_uncond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
kv_caches.clear()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return GlmImagePipelineOutput(images=image)
|
||||
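For context, a minimal usage sketch of the `GlmImagePipeline` whose `__call__` tail is shown above. The checkpoint id `zai-org/GLM-Image` is taken from the test file later in this diff; the dtype and device choices are assumptions, not part of the change.

```python
import torch
from diffusers import GlmImagePipeline

# Assumed checkpoint id and dtype; batch size stays at 1 because of the AR prior constraint above.
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

output = pipe(
    prompt="A photo of a cat",
    height=1024,
    width=1024,
    num_inference_steps=50,
    guidance_scale=1.5,
    generator=torch.Generator("cuda").manual_seed(0),
)
output.images[0].save("glm_image.png")
```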
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class GlmImagePipelineOutput(BaseOutput):
    """
    Output class for GLM-Image pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array represent the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
@@ -260,10 +260,10 @@ class LongCatImagePipeline(DiffusionPipeline, FromSingleFileMixin):
            text = self.text_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
            all_text.append(text)

        inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(self.text_encoder.device)
        inputs = self.text_processor(text=all_text, padding=True, return_tensors="pt").to(device)

        self.text_encoder.to(device)
        generated_ids = self.text_encoder.generate(**inputs, max_new_tokens=self.tokenizer_max_length)
        generated_ids.to(device)
        generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        output_text = self.text_processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False

@@ -758,7 +758,6 @@ def load_sub_model(
    use_safetensors: bool,
    dduf_entries: Optional[Dict[str, DDUFEntry]],
    provider_options: Any,
    disable_mmap: bool,
    quantization_config: Optional[Any] = None,
):
    """Helper method to load the module `name` from `library_name` and `class_name`"""
@@ -860,9 +859,6 @@ def load_sub_model(
    else:
        loading_kwargs["low_cpu_mem_usage"] = False

    if is_diffusers_model:
        loading_kwargs["disable_mmap"] = disable_mmap

    if is_transformers_model and is_transformers_version(">=", "4.57.0"):
        loading_kwargs.pop("offload_state_dict")

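The `load_sub_model` hunk above threads a pipeline-level `quantization_config` down to each sub-model, and `from_pretrained` (later in this diff) pops that kwarg from `kwargs`. A hedged sketch of how the kwarg is passed in practice; the `PipelineQuantizationConfig` name, its import path, and its constructor arguments are assumptions based on the diffusers quantization docs, not part of this diff.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig  # assumed import path

# Quantize only the transformer sub-model with bitsandbytes 8-bit (assumed backend name).
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_8bit",
    quant_kwargs={"load_in_8bit": True},
    components_to_quantize=["transformer"],
)

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder checkpoint, any pipeline repo works
    quantization_config=pipeline_quant_config,  # popped in from_pretrained, forwarded to load_sub_model
    torch_dtype=torch.bfloat16,
)
```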
@@ -60,7 +60,6 @@ from ..utils import (
    deprecate,
    is_accelerate_available,
    is_accelerate_version,
    is_bitsandbytes_version,
    is_hpu_available,
    is_torch_npu_available,
    is_torch_version,
@@ -445,10 +444,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(module)

            # https://github.com/huggingface/accelerate/pull/3907
            if is_loaded_in_8bit_bnb and (
                is_bitsandbytes_version("<", "0.48.0") or is_accelerate_version("<", "1.13.0.dev0")
            ):
            if is_loaded_in_8bit_bnb:
                return False

            return hasattr(module, "_hf_hook") and (
@@ -527,10 +523,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
                )

            if is_loaded_in_8bit_bnb and device is not None and is_bitsandbytes_version("<", "0.48.0"):
            if is_loaded_in_8bit_bnb and device is not None:
                logger.warning(
                    f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
                    "You need to upgrade bitsandbytes to at least 0.48.0"
                )

            # Note: we also handle this at the ModelMixin level. The reason for doing it here too is that modeling
@@ -547,14 +542,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
            # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
                module.to(device=device)
            # added here https://github.com/huggingface/transformers/pull/43258
            if (
                is_loaded_in_8bit_bnb
                and device is not None
                and is_transformers_version(">", "4.58.0")
                and is_bitsandbytes_version(">=", "0.48.0")
            ):
                module.to(device=device)
            elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb and not is_group_offloaded:
                module.to(device, dtype)

@@ -721,9 +708,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                loading `from_flax`.
            dduf_file(`str`, *optional*):
                Load weights from the specified dduf file.
            disable_mmap ('bool', *optional*, defaults to 'False'):
                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.

        > [!TIP]
        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log in
        > with `hf auth login`.
@@ -775,7 +759,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        use_onnx = kwargs.pop("use_onnx", None)
        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
        quantization_config = kwargs.pop("quantization_config", None)
        disable_mmap = kwargs.pop("disable_mmap", False)

        if torch_dtype is not None and not isinstance(torch_dtype, dict) and not isinstance(torch_dtype, torch.dtype):
            torch_dtype = torch.float32
@@ -1063,7 +1046,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                    use_safetensors=use_safetensors,
                    dduf_entries=dduf_entries,
                    provider_options=provider_options,
                    disable_mmap=disable_mmap,
                    quantization_config=quantization_config,
                )
                logger.info(
@@ -1241,9 +1223,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):

            # This is because the model would already be placed on a CUDA device.
            _, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
            if is_loaded_in_8bit_bnb and (
                is_transformers_version("<", "4.58.0") or is_bitsandbytes_version("<", "0.48.0")
            ):
            if is_loaded_in_8bit_bnb:
                logger.info(
                    f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
                )

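The `DiffusionPipeline` hunks above gate device moves for 8-bit bitsandbytes modules on bitsandbytes >= 0.48.0 (and transformers > 4.58.0). A hedged sketch of what such a move looks like at the model level; the checkpoint id is only an example and the version behavior follows the guarded branches shown above.

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# Load a sub-model in 8-bit with bitsandbytes.
quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.float16,
)

# With bitsandbytes >= 0.48.0 this device move is allowed; with older versions the pipeline-level
# .to() only warns and leaves the module where it is, as the removed warning branch above shows.
model.to("cuda")
```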
@@ -982,21 +982,6 @@ class FluxTransformer2DModel(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class GlmImageTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class HiDreamImageTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]


@@ -632,21 +632,6 @@ class ChromaImg2ImgPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class ChromaInpaintPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class ChromaPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

@@ -1157,21 +1142,6 @@ class FluxPriorReduxPipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class GlmImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class HiDreamImagePipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

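The dummy classes above follow the standard diffusers guard pattern: the public symbol always imports, but any use raises an informative ImportError when the listed backends are missing. A minimal sketch of the same pattern; the names mirror `diffusers.utils`, and the exact import path is an assumption.

```python
from diffusers.utils import DummyObject, requires_backends


class MyTorchOnlyModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Raises ImportError with an installation hint if torch is not available.
        requires_backends(self, ["torch"])
```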
@@ -1,227 +0,0 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, GlmImagePipeline, GlmImageTransformer2DModel
from diffusers.utils import is_transformers_version

from ...testing_utils import enable_full_determinism, require_torch_accelerator, require_transformers_version_greater
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin


if is_transformers_version(">=", "5.0.0.dev0"):
    from transformers import GlmImageConfig, GlmImageForConditionalGeneration, GlmImageProcessor


enable_full_determinism()


@require_transformers_version_greater("4.57.4")
@require_torch_accelerator
class GlmImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = GlmImagePipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    test_attention_slicing = False
    supports_dduf = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        glm_config = GlmImageConfig(
            text_config={
                "vocab_size": 168064,
                "hidden_size": 32,
                "intermediate_size": 32,
                "num_hidden_layers": 2,
                "num_attention_heads": 2,
                "num_key_value_heads": 2,
                "max_position_embeddings": 512,
                "vision_vocab_size": 128,
                "rope_parameters": {"mrope_section": (4, 2, 2)},
            },
            vision_config={
                "depth": 2,
                "hidden_size": 32,
                "num_heads": 2,
                "image_size": 32,
                "patch_size": 8,
                "intermediate_size": 32,
            },
            vq_config={"embed_dim": 32, "num_embeddings": 128, "latent_channels": 32},
        )

        torch.manual_seed(0)
        vision_language_encoder = GlmImageForConditionalGeneration(glm_config)

        processor = GlmImageProcessor.from_pretrained("zai-org/GLM-Image", subfolder="processor")

        torch.manual_seed(0)
        # For GLM-Image, the relationship between components must satisfy:
        # patch_size × vae_scale_factor = 16 (since AR tokens are upsampled 2× from d32)
        transformer = GlmImageTransformer2DModel(
            patch_size=2,
            in_channels=4,
            out_channels=4,
            num_layers=2,
            attention_head_dim=8,
            num_attention_heads=2,
            text_embed_dim=text_encoder.config.hidden_size,
            time_embed_dim=16,
            condition_dim=8,
            prior_vq_quantizer_codebook_size=128,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=(4, 8, 16, 16),
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            norm_num_groups=4,
            sample_size=128,
            latents_mean=[0.0] * 4,
            latents_std=[1.0] * 4,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        components = {
            "tokenizer": tokenizer,
            "processor": processor,
            "text_encoder": text_encoder,
            "vision_language_encoder": vision_language_encoder,
            "vae": vae,
            "transformer": transformer,
            "scheduler": scheduler,
        }

        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        height, width = 32, 32

        inputs = {
            "prompt": "A photo of a cat",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.5,
            "height": height,
            "width": width,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images[0]
        generated_slice = image.flatten()
        generated_slice = np.concatenate([generated_slice[:8], generated_slice[-8:]])

        # fmt: off
        expected_slice = np.array(
            [
                0.5796329, 0.5005878, 0.45881274, 0.45331675, 0.43688118, 0.4899527, 0.54017603, 0.50983673, 0.3387968, 0.38074082, 0.29942477, 0.33733928, 0.3672544, 0.38462338, 0.40991822, 0.46641728
            ]
        )
        # fmt: on

        self.assertEqual(image.shape, (3, 32, 32))
        self.assertTrue(np.allclose(expected_slice, generated_slice, atol=1e-4, rtol=1e-4))

    @unittest.skip("Not supported.")
    def test_inference_batch_single_identical(self):
        # GLM-Image has batch_size=1 constraint due to AR model
        pass

    @unittest.skip("Not supported.")
    def test_inference_batch_consistent(self):
        # GLM-Image has batch_size=1 constraint due to AR model
        pass

    @unittest.skip("Not supported.")
    def test_num_images_per_prompt(self):
        # GLM-Image has batch_size=1 constraint due to AR model
        pass

    @unittest.skip("Needs to be revisited.")
    def test_encode_prompt_works_in_isolation(self):
        pass

    @unittest.skip("Needs to be revisited.")
    def test_pipeline_level_group_offloading_inference(self):
        pass

    @unittest.skip(
        "The following tests are relaxed because this pipeline doesn't guarantee the same outputs for the same inputs in consecutive runs."
    )
    def test_dict_tuple_outputs_equivalent(self):
        pass

    @unittest.skip("Skipped")
    def test_cpu_offload_forward_pass_twice(self):
        pass

    @unittest.skip("Skipped")
    def test_sequential_offload_forward_pass_twice(self):
        pass

    @unittest.skip("Skipped")
    def test_float16_inference(self):
        pass

    @unittest.skip("Skipped")
    def test_save_load_float16(self):
        pass

    @unittest.skip("Skipped")
    def test_save_load_local(self):
        pass
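The deleted fast test above pins outputs with a 16-value regression slice. A small sketch of how that slice is built from the generated image, assuming an `output_type="pt"` result of shape `(3, 32, 32)`; the random array below only stands in for the pipeline output.

```python
import numpy as np

# Stand-in for pipe(**inputs).images[0] converted to a numpy array.
image = np.random.rand(3, 32, 32).astype(np.float32)

flat = image.flatten()
generated_slice = np.concatenate([flat[:8], flat[-8:]])  # first and last 8 values
assert generated_slice.shape == (16,)
# In the test, this slice is compared against a stored baseline with np.allclose(..., atol=1e-4, rtol=1e-4).
```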
@@ -288,29 +288,31 @@ class BnB8bitBasicTests(Base8bitTests):
        self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)
        self.assertTrue(hasattr(linear.weight, "SCB"))

    @require_bitsandbytes_version_greater("0.48.0")
    def test_device_and_dtype_assignment(self):
        r"""
        Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
        Checks also if other models are cast correctly.
        """
        with self.assertRaises(ValueError):
            # Tries with `str`
            self.model_8bit.to("cpu")

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            self.model_8bit.to(torch.float16)

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.to(torch.device(f"{torch_device}:0"))

        with self.assertRaises(ValueError):
            # Tries with a `device`
            self.model_8bit.float()

        with self.assertRaises(ValueError):
            # Tries with a `dtype`
            # Tries with a `device`
            self.model_8bit.half()

        # This should work with 0.48.0
        self.model_8bit.to("cpu")
        self.model_8bit.to(torch.device(f"{torch_device}:0"))

        # Test if we did not break anything
        self.model_fp16 = self.model_fp16.to(dtype=torch.float32, device=torch_device)
        input_dict_for_transformer = self.get_dummy_inputs()
@@ -835,7 +837,7 @@ class BaseBnb8bitSerializationTests(Base8bitTests):


@require_torch_version_greater_equal("2.6.0")
@require_bitsandbytes_version_greater("0.48.0")
@require_bitsandbytes_version_greater("0.45.5")
class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
    @property
    def quantization_config(self):
@@ -846,7 +848,7 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
        )

    @pytest.mark.xfail(
        reason="Test fails because of a type change when recompiling."
        reason="Test fails because of an offloading problem from Accelerate with confusion in hooks."
        " Test passes without recompilation context manager. Refer to https://github.com/huggingface/diffusers/pull/12002/files#r2240462757 for details."
    )
    def test_torch_compile(self):
@@ -856,5 +858,6 @@ class Bnb8BitCompileTests(QuantCompileTests, unittest.TestCase):
    def test_torch_compile_with_cpu_offload(self):
        super()._test_torch_compile_with_cpu_offload(torch_dtype=torch.float16)

    @pytest.mark.xfail(reason="Test fails because of an offloading problem from Accelerate with confusion in hooks.")
    def test_torch_compile_with_group_offload_leaf(self):
        super()._test_torch_compile_with_group_offload_leaf(torch_dtype=torch.float16, use_stream=True)
